In [None]:
import time
import numpy as np
from PIL import Image
import random
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from EvolutionSimulation.python.neuralNetworks.ViT import tokenizer


dataset2 = []
files = [
        "/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_lion.npy", 
         #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_crocodile.npy", 
         #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_dragon.npy", 
         "/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_sheep.npy", 
         #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_duck.npy"
        ]

# Assigning new labels
class_labels = {
    "/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_lion.npy": 1,
    #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_crocodile.npy": 1,
    #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_dragon.npy": 1,
    "/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_sheep.npy": 0,
    #"/home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_duck.npy": 0
}

for filename in files:
    images = np.load(rf"{filename}")
    print(f"Loaded {filename} with shape: {images.shape}")

    t_0 = time.perf_counter()
    count = 0

    # Loop through each image in the file
    for i in range(len(images)):
        # Only process the first 1000 images from each class

        image = images[i]  # Provides (728,) array
        reshape = image.reshape(28, 28)  # Reshapes to (28, 28) numpy array
        image = Image.fromarray(reshape)
        grayscale_image = image.convert("L")

        # Assign the label based on the class of the file
        label = class_labels[filename]  # Get the label for the class

        data = {
            'image': grayscale_image,  # The image tensor
            'label': label  # The corresponding label (class)
        }

        dataset2.append(data)
        count += 1

    t_1 = time.perf_counter()
    print(f"Successfully processed {filename} in {t_1 - t_0:.2f} seconds")


random.shuffle(dataset2)


class MyCustomDataset(Dataset):
    def __init__(self):
        self.dataset = dataset2

        self.transform = T.ToTensor()

        self.captions = {
            1: "a drawing of a lion",
            1: "a drawing of a crocodile",
            1: "a drawing of a dragon",
            0: "a drawing of a sheep",
            0: "a drawing of a duck"
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        img = self.dataset[i]["image"]
        img = self.transform(img)

        cap, mask = tokenizer(self.captions[self.dataset[i]["label"]])

        mask = mask.repeat(len(mask), 1)

        return {"image": img, "caption": cap, "mask": mask}



batch_size = 128
print("Loading Data...")
train_set = MyCustomDataset()
test_set = MyCustomDataset()

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=True, batch_size=batch_size)
print("Data Loaded")


Loaded /home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_lion.npy with shape: (120949, 784)
Successfully processed /home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_lion.npy in 0.64 seconds
Loaded /home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_sheep.npy with shape: (126121, 784)
Successfully processed /home/allan/nvim/projects/EvolutionSimulation/EvolutionSimulation/data/rawData/full_numpy_bitmap_sheep.npy in 1.45 seconds
Loading Data...
Data Loaded


In [6]:
train_set[0]

{'image': tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1608, 0.4392,
           0.3059, 0.1412, 0.1098, 0.4000, 0.2471, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0078, 0.3333, 0.7451, 0.9961, 1.0000,
           1.0000, 0.8902, 0.7490, 1.0000, 1.0000, 0.7686, 0.3490, 0.8745,
           0.9098, 0.2706, 0.0000,