<a href="https://colab.research.google.com/github/ecflorui/genesys-lab/blob/main/resnet18_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os, glob
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

# -------- Paths (Google Drive) --------
v5_path = '/content/drive/MyDrive/v5'
v6_path = '/content/drive/MyDrive/v6 (1)'
save_path = '/content/drive/MyDrive/best_model_resnet18.pth'

# -------- Hyperparameters --------
BATCH_SIZE = 32
EPOCHS = 10
PATIENCE = 5

# -------- Dataset Class --------
class ImageLabelDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        self.samples = []
        self.transform = transform
        for root in root_dirs:
            print(f"Checking files in: {root}")
            for image_file in glob.glob(os.path.join(root, '*.*')):
                if image_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    npy_file = os.path.splitext(image_file)[0] + '.npy'
                    if os.path.exists(npy_file):
                        self.samples.append((image_file, npy_file))
                    else:
                        print(f"Missing .npy for: {image_file}")
        print(f"Total valid samples found: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label_path = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        label_vec = np.load(label_path)
        label = int(np.argmax(label_vec))
        if self.transform:
            image = self.transform(image)
        return image, label

# -------- Transforms --------
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# -------- Load Data --------
full_dataset = ImageLabelDataset([v5_path, v6_path])

if len(full_dataset) == 0:
    raise ValueError("Dataset is empty. Check your Drive paths and file formats.")

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

train_ds.dataset.transform = train_transform
val_ds.dataset.transform = val_transform

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# -------- Model --------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = models.resnet18(pretrained=True)
num_features = base_model.fc.in_features
base_model.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(num_features, 4)
)
model = base_model.to(device)

# -------- Loss and Optimizer --------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# -------- Evaluation Function --------
def evaluate(loader):
    model.eval()
    correct, total, loss_sum = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss_sum += loss.item() * imgs.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return loss_sum / total, correct / total

# -------- Training Loop --------
best_val_acc = 0.0
patience_counter = 0

for epoch in range(EPOCHS):
    model.train()
    for imgs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    train_loss, train_acc = evaluate(train_loader)
    val_loss, val_acc = evaluate(val_loader)
    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | Val Loss={val_loss:.4f}, Acc={val_acc:.4f}')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), save_path)
        print(f"Saved new best model with Val Acc: {val_acc:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break


Checking files in: /content/drive/MyDrive/v5
Missing .npy for: /content/drive/MyDrive/v5/20250721_164649.jpg
Missing .npy for: /content/drive/MyDrive/v5/20250721_181726.jpg
Checking files in: /content/drive/MyDrive/v6 (1)
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_163126.jpg
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_171303.jpg
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_174704.jpg
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_183244.jpg
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_185623.jpg
Missing .npy for: /content/drive/MyDrive/v6 (1)/20250722_190931.jpg
Total valid samples found: 587


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 180MB/s]
Epoch 1: 100%|██████████| 15/15 [03:44<00:00, 15.00s/it]


Epoch 1: Train Loss=1.0246, Acc=0.6162 | Val Loss=1.1968, Acc=0.5932
Saved new best model with Val Acc: 0.5932


Epoch 2: 100%|██████████| 15/15 [01:19<00:00,  5.31s/it]


Epoch 2: Train Loss=0.6644, Acc=0.7591 | Val Loss=1.1246, Acc=0.6186
Saved new best model with Val Acc: 0.6186


Epoch 3: 100%|██████████| 15/15 [01:20<00:00,  5.37s/it]


Epoch 3: Train Loss=0.3109, Acc=0.9360 | Val Loss=1.1530, Acc=0.6525
Saved new best model with Val Acc: 0.6525


Epoch 4: 100%|██████████| 15/15 [01:19<00:00,  5.32s/it]


Epoch 4: Train Loss=0.1257, Acc=0.9979 | Val Loss=1.2702, Acc=0.6017


Epoch 5: 100%|██████████| 15/15 [01:18<00:00,  5.27s/it]


Epoch 5: Train Loss=0.0512, Acc=1.0000 | Val Loss=1.4214, Acc=0.5763


Epoch 6: 100%|██████████| 15/15 [01:19<00:00,  5.29s/it]


Epoch 6: Train Loss=0.0284, Acc=0.9979 | Val Loss=1.5281, Acc=0.5254


Epoch 7: 100%|██████████| 15/15 [01:20<00:00,  5.34s/it]


Epoch 7: Train Loss=0.0240, Acc=0.9979 | Val Loss=1.6580, Acc=0.5593


Epoch 8:  80%|████████  | 12/15 [01:04<00:16,  5.41s/it]