In [1]:
import torch
import torchvision

from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, random_split

import os
import json
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

In [2]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class MSINTDataset(Dataset):
    """Image-classification dataset laid out as ``<path>/<class_name>/<files>``.

    Every immediate subdirectory of ``path`` is one class. Class indices are
    assigned in sorted order of the directory names, so two datasets built
    from roots with the same class directories (e.g. train and test) share
    one label mapping. Every file anywhere below a class directory is
    labelled with that class; files directly inside ``path`` are ignored.

    Args:
        path: Root directory containing one subdirectory per class.
        transform: Optional callable applied in ``__getitem__`` to each
            sample, which is the image converted to a NumPy array.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform

        # Sorted: os.listdir/os.walk order is arbitrary, and an unsorted
        # mapping could assign different indices to the same class in the
        # train and test datasets.
        self.classes = sorted(
            entry for entry in os.listdir(path)
            if os.path.isdir(os.path.join(path, entry))
        )
        self.class_to_idx = {
            cls_name: i for i, cls_name in enumerate(self.classes)
        }

        # (file_path, class_index) pairs. The label comes from the class
        # directory being walked, not from the name of the innermost
        # directory, so nested folders and arbitrary class names both work
        # on every OS.
        self.data_list = []
        for cls_name in self.classes:
            target = self.class_to_idx[cls_name]
            for path_dir, dir_list, file_list in os.walk(os.path.join(path, cls_name)):
                dir_list.sort()  # in-place sort makes os.walk traversal deterministic
                for name_file in sorted(file_list):
                    self.data_list.append((os.path.join(path_dir, name_file), target))

        self.len_dataset = len(self.data_list)

    def __len__(self):
        """Return the number of samples."""
        return self.len_dataset

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        ``sample`` is the image as a NumPy array, passed through
        ``self.transform`` when one is set; ``target`` is the class index.
        """
        file_path, target = self.data_list[index]
        # The context manager closes the file handle. np.array() forces the
        # pixel data to load before the file is closed.
        with Image.open(file_path) as img:
            sample = np.array(img)

        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

In [3]:
# Build the train and test splits (train first, then test) from their
# class-per-directory roots.
train_data, test_data = (MSINTDataset(split_root) for split_root in ('data/train', 'data/test'))

In [4]:
# Mini-batches of 16 training samples, reshuffled every epoch.
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=16)

In [11]:
import pygame
import numpy as np

# RGB palette for the Game of Life renderer.
color_grid = (40, 40, 40)           # window fill; shows through the 1px gaps between cells
color_bg = (10, 10, 10)             # dead cell
color_alive_next = (255, 255, 255)  # live cell / cell alive in the next generation
# Dying-cell colour when update(..., with_progress=True). It equals the dead
# colour so dying cells are not highlighted; a visible alternative is
# (170, 170, 170).
color_die_next = (10, 10, 10)


def update(screen, cells, size, with_progress=False):
    """Compute the next Game of Life generation and draw the grid.

    The board is bounded: positions outside ``cells`` count as dead on every
    edge. Each cell is drawn with ``pygame.draw.rect`` as a
    ``(size - 1)``-pixel square at ``(col * size, row * size)``, leaving a
    1-pixel gap between cells.

    Args:
        screen: Surface passed to ``pygame.draw.rect``.
        cells: 2-D array of 0/1 values holding the current generation. It is
            not modified.
        size: Cell pitch in pixels.
        with_progress: If True, colour cells by their next-generation fate:
            dying cells use ``color_die_next``, surviving and newborn cells
            use ``color_alive_next``. Otherwise colour by the current state.

    Returns:
        A new float array of the same shape as ``cells`` holding the next
        generation.
    """
    # zeros, not empty: dead cells that stay dead are never assigned below,
    # so an uninitialised buffer would leak garbage into the result.
    updated_cells = np.zeros((cells.shape[0], cells.shape[1]))

    for row, col in np.ndindex(cells.shape):
        # Clamp the slice starts at 0. A start of -1 wraps to the end of the
        # axis and yields an empty slice, so top-row / left-column cells
        # would otherwise never see their neighbours. Stops past the end
        # already clip.
        row_lo, col_lo = max(row - 1, 0), max(col - 1, 0)
        alive = np.sum(cells[row_lo:row + 2, col_lo:col + 2]) - cells[row, col]
        color = color_bg if cells[row, col] == 0 else color_alive_next

        if cells[row, col] == 1:
            if alive < 2 or alive > 3:
                # Under- or over-population: the cell dies (stays 0).
                if with_progress:
                    color = color_die_next
            else:
                # 2 or 3 live neighbours: the cell survives.
                updated_cells[row, col] = 1
                if with_progress:
                    color = color_alive_next
        elif alive == 3:
            # Reproduction: a dead cell with exactly 3 live neighbours is born.
            updated_cells[row, col] = 1
            if with_progress:
                color = color_alive_next

        pygame.draw.rect(screen, color, (col * size, row * size, size - 1, size - 1))

    return updated_cells


def main():
    """Run an interactive Game of Life editor on a 28x28 grid.

    Controls: hold the left mouse button to set cells alive, press SPACE to
    start/stop the simulation, and close the window to quit pygame and
    return.
    """
    n_cells = 28     # the grid is n_cells x n_cells
    cell_size = 10   # pixels per cell (pitch passed to update())
    fps = 60

    pygame.init()
    # NOTE(review): the window (840x644 px) is larger than the grid
    # (280x280 px). The size is kept as in the original; clicks outside the
    # grid are ignored below.
    screen = pygame.display.set_mode((28*30, 28*23))
    # Create the clock once. Clock.tick() throttles relative to the previous
    # tick on the same clock, so a fresh clock every frame breaks the cap.
    clock = pygame.time.Clock()

    cells = np.zeros((n_cells, n_cells))
    screen.fill(color_grid)
    update(screen, cells, cell_size)
    pygame.display.flip()

    running = False

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                return
            if event.type == pygame.KEYDOWN and event.key == pygame.K_SPACE:
                running = not running
                update(screen, cells, cell_size)
                pygame.display.update()

        # Poll the button once per frame rather than once per queued event.
        if pygame.mouse.get_pressed()[0]:
            x, y = pygame.mouse.get_pos()
            row, col = y // cell_size, x // cell_size
            # Without this bounds check, a click outside the grid raises
            # IndexError.
            if 0 <= row < n_cells and 0 <= col < n_cells:
                cells[row, col] = 1
                update(screen, cells, cell_size)
                pygame.display.update()

        screen.fill(color_grid)

        if running:
            cells = update(screen, cells, cell_size)
            pygame.display.update()

        clock.tick(fps)


# In a notebook cell __name__ is '__main__', so executing the cell starts the
# (blocking) pygame loop.
if __name__ == '__main__':
    main()
    



In [None]:
pip install pygame