In [1]:
import os

import kornia.augmentation as K
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.transforms.functional as TF
import numpy as np

In [2]:
IMAGE_SIZE = 48
NUM_CLASSES = 21
BATCH_SIZE = 32

In [3]:
def transform(image):
    # Apply random horizontal flip with probability 0.5
    image = K.RandomHorizontalFlip(p=0.5)(image)
    # Apply random affine transformation with rotation angle range [-15,15] degrees
    image = K.RandomAffine(degrees=(-15, 15))(image)
    # Apply random color jitter with brightness range [0.8,1.2]
    image = K.ColorJitter(brightness=(0.8, 1.2))(image)

    return image

In [4]:
def convert_label_to_int(labels):
    label_dict = {}
    rev_label_dict = {}

    for idx, label in enumerate(labels):
        label_dict[label] = idx
        rev_label_dict[idx] = label

    return label_dict, rev_label_dict

In [25]:
class ImageClassificationDataModule(pl.LightningDataModule):
    def __init__(self, data_dirs, label2idx):
        super().__init__()
        self.data_dirs = data_dirs
        self.label2idx = label2idx

    def setup(self, stage=None):
        self.images = []
        self.labels = []
        for dir in self.data_dirs:
            for filename in os.listdir(dir):
                # Read webp image using Pillow
                image = Image.open(os.path.join(dir, filename))
                # Resize image to IMAGE_SIZE x IMAGE_SIZE
                image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
                # Convert image to tensor and normalize iter
                image = TF.to_grayscale(image)
                image = TF.to_tensor(image)
                mean, std = image.mean([1,2]), image.std([1,2])
                image = TF.normalize(image, mean=mean, std=std)
                self.images.append(image)
                # Use filename as label (without extension)
                label = os.path.splitext(filename)[0]
                # Convert label to integer (labels should go from 0 to NUM_CLASSES - 1)
                label = self.label2idx[label]
                self.labels.append(label)
        # Split images and labels into train and val sets
        split_idx = int(len(self.images) * 0.6)
        self.train_images = self.images[:split_idx]
        self.train_labels = self.labels[:split_idx]
        self.val_images = self.images[split_idx:]
        self.val_labels = self.labels[split_idx:]

    def train_dataloader(self):
        # Create a dataset from train images and labels
        train_dataset = TensorDataset(torch.stack(self.train_images), torch.tensor(self.train_labels))
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
        return train_dataloader

    def val_dataloader(self):
        # Create a TensorDataset from val images and labels
        val_dataset = TensorDataset(torch.stack(self.val_images), torch.tensor(self.val_labels))
        # Create a dataloader from val dataset with shuffle=False and batch_size=BATCH_SIZE
        val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

        return val_dataloader

In [31]:
class ImageClassificationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = torch.nn.Linear(32 * IMAGE_SIZE // 4 * IMAGE_SIZE // 4, 128)
        self.fc2 = torch.nn.Linear(128, NUM_CLASSES)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 32 * IMAGE_SIZE // 4 * IMAGE_SIZE // 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        # Get images and labels from batch
        images, labels = batch
        # Get logits from model
        logits = self(images)
        # Compute cross entropy loss
        loss = F.cross_entropy(logits, labels)
        # Log loss to tensorboard
        self.log('train_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        # Get images and labels from batch
        images, labels = batch
        # Get logits from model
        logits = self(images)
        # Compute cross entropy loss
        loss = F.cross_entropy(logits, labels)
        # Log loss to tensorboard
        self.log('val_loss', loss)

        return loss

    def configure_optimizers(self):
        # Define optimizer (e.g., Adam)
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

        return optimizer


In [27]:
# Data augmentation block
for filename in os.listdir('images'):
    image = Image.open(os.path.join('images', filename))
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    image = TF.to_tensor(image)[0:3,:,:]
    mean, std = image.mean([1,2]), image.std([1,2])
    image = TF.normalize(image, mean=mean, std=std)

    transformed_image = transform(image)
    pil_image = TF.to_pil_image(transformed_image.squeeze(0))
    pil_image.save(f'processed2/{os.path.splitext(filename)[0]}.png')

In [33]:
labels = []
images = []
for filename in os.listdir('processed'):
    # Read webp image using Pillow
    image = Image.open(os.path.join('processed', filename))
    image = TF.to_tensor(image)
    label = os.path.splitext(filename)[0]
    # Convert label to integer (labels should go from 0 to NUM_CLASSES - 1)
    labels.append(label)

label2id, id2label = convert_label_to_int(labels)

In [34]:
print('ciao')
data_module = ImageClassificationDataModule(['images', 'processed', 'processed2', 'test'], label2id)
model = ImageClassificationModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, data_module)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type      | Params
------------------------------------
0 | conv1 | Conv2d    | 160   
1 | pool1 | MaxPool2d | 0     
2 | conv2 | Conv2d    | 4.6 K 
3 | pool2 | MaxPool2d | 0     
4 | fc1   | Linear    | 589 K 
5 | fc2   | Linear    | 2.7 K 
------------------------------------
597 K     Trainable params
0         Non-trainable params
597 K     Total params
2.390     Total estimated model params size (MB)


ciao


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [30]:
test_image = Image.open('test/FulliconStatusEffects_broken.png')
test_image = TF.center_crop(test_image, [IMAGE_SIZE,IMAGE_SIZE])
test_image.save('test/cocco.png')
test_image = TF.to_grayscale(test_image)
test_image.save('test/cocco.png')
test_image = TF.to_tensor(test_image)
result = model(test_image)
res_idx = np.argmax(result.detach())
id2label[res_idx.item()]

'FulliconStatusEffects_oblivious'