In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import os
from PIL import Image
import json
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, recall_score, accuracy_score
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from sklearn.metrics import recall_score
from tqdm import tqdm
from torchvision.utils import save_image

In [31]:
class GeneratedImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.label_names = {'Live': 0, 'Print': 1, 'Papercut': 2, 'Replay': 3, '3D': 4}
        
        for label_name, label in self.label_names.items():
            label_dir = os.path.join(self.root_dir, label_name)
            for img_name in os.listdir(label_dir):
                self.images.append(os.path.join(label_dir, img_name))
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label


In [32]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)  
        self.fc2 = nn.Linear(128, 5)  

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)  
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [33]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize the images
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images
])

pretrain_dataset = GeneratedImageDataset(root_dir='generated_data/', transform=transform)
pretrain_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(pretrain_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 50 == 9:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 10:.4f}')
            running_loss = 0.0

print('Finished Training')

[1, 10] loss: 1.5177
[1, 60] loss: 1.9713
[1, 110] loss: 0.6261
[2, 10] loss: 0.1046
[2, 60] loss: 0.3117
[2, 110] loss: 0.2325
[3, 10] loss: 0.0239
[3, 60] loss: 0.1303
[3, 110] loss: 0.2046
[4, 10] loss: 0.0276
[4, 60] loss: 0.0358
[4, 110] loss: 0.1060
[5, 10] loss: 0.0074
[5, 60] loss: 0.0338
[5, 110] loss: 0.0351
[6, 10] loss: 0.0023
[6, 60] loss: 0.0152
[6, 110] loss: 0.0221
[7, 10] loss: 0.0090
[7, 60] loss: 0.1026
[7, 110] loss: 0.0220
[8, 10] loss: 0.0058
[8, 60] loss: 0.0148
[8, 110] loss: 0.0797
[9, 10] loss: 0.0596
[9, 60] loss: 0.0503
[9, 110] loss: 0.0158
[10, 10] loss: 0.0026
[10, 60] loss: 0.0291
[10, 110] loss: 0.0064
Finished Training


In [35]:
full_dataset = GeneratedImageDataset(root_dir='multiclass_data/', transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

dataloaders = {'train': train_loader, 'val': val_loader}

In [36]:
num_epochs = 5
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)
    
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train() 
        else:
            model.eval()   

        running_loss = 0.0
        all_preds = []
        all_labels = []

        for inputs, labels in dataloaders[phase]:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()


            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_accuracy = accuracy_score(all_labels, all_preds)
        epoch_precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        epoch_recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        epoch_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

        print(f'{phase.upper()} Loss: {epoch_loss:.4f} Acc: {epoch_accuracy:.4f} Precision: {epoch_precision:.4f} Recall: {epoch_recall:.4f} F1: {epoch_f1:.4f}')

print('Training and evaluation complete')

Epoch 1/5
----------
TRAIN Loss: 0.4000 Acc: 0.8590 Precision: 0.8456 Recall: 0.8306 F1: 0.8371
VAL Loss: 0.1937 Acc: 0.9365 Precision: 0.9321 Recall: 0.9250 F1: 0.9280
Epoch 2/5
----------
TRAIN Loss: 0.1394 Acc: 0.9539 Precision: 0.9498 Recall: 0.9491 F1: 0.9494
VAL Loss: 0.1479 Acc: 0.9503 Precision: 0.9536 Recall: 0.9412 F1: 0.9465
Epoch 3/5
----------
TRAIN Loss: 0.0750 Acc: 0.9760 Precision: 0.9750 Recall: 0.9754 F1: 0.9752
VAL Loss: 0.1426 Acc: 0.9600 Precision: 0.9598 Recall: 0.9548 F1: 0.9572
Epoch 4/5
----------
TRAIN Loss: 0.0474 Acc: 0.9841 Precision: 0.9829 Recall: 0.9830 F1: 0.9830
VAL Loss: 0.1617 Acc: 0.9554 Precision: 0.9591 Recall: 0.9515 F1: 0.9550
Epoch 5/5
----------
TRAIN Loss: 0.0327 Acc: 0.9890 Precision: 0.9889 Recall: 0.9890 F1: 0.9890
VAL Loss: 0.1804 Acc: 0.9591 Precision: 0.9591 Recall: 0.9560 F1: 0.9572
Training and evaluation complete


In [37]:
torch.save(model, 'gan_simple_cnn.pth')