In [2]:
import torch
torch.__version__
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pathlib import Path
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim


In [3]:
#!pip install scikit-learn
#!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [4]:

class SimpleCNN(nn.Module):
    """
    ResNet18-based image classifier with an embedding bottleneck.

    - ResNet18 backbone (fully fine-tuned; nothing is frozen)
    - ``embed_dim``-dim embedding (Linear -> ReLU -> Dropout) for anomaly detection
    - linear classifier head with ``num_classes`` outputs (6 by default)

    Args:
        num_classes: number of output logits produced by ``forward``.
        embed_dim: size of the embedding returned by ``forward_features``.
        pretrained: if True (default), initialise the backbone with torchvision's
            ImageNet weights (downloaded on first use). Pass False when the
            weights will be overwritten anyway, e.g. right before
            ``load_state_dict`` on a saved checkpoint.
    """
    def __init__(self, num_classes: int = 6, embed_dim: int = 256, pretrained: bool = True):
        super().__init__()

        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Width of the pooled backbone features (512 for ResNet18). Read it from
        # the original fc layer rather than hard-coding it.
        in_features = backbone.fc.in_features

        # All backbone children except the final fc layer (conv stages + global
        # average pooling); the fc is replaced by our own embedding + head.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.embedding = nn.Sequential(
            nn.Linear(in_features, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
        )

        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward_features(self, x):
        """
        Return the embedding vector of shape (B, embed_dim).

        Used by ``forward`` and intended for anomaly detection later.
        """
        x = self.feature_extractor(x)  # pooled backbone features
        x = torch.flatten(x, 1)        # (B, in_features)
        x = self.embedding(x)          # (B, embed_dim)
        return x

    def forward(self, x):
        """Return class logits of shape (B, num_classes)."""
        feats = self.forward_features(x)
        logits = self.classifier(feats)
        return logits


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Seed torch's global RNG so shuffling, random augmentations and weight init
# repeat across Restart & Run All (torchvision's random transforms draw from
# torch's RNG). Bitwise determinism on CUDA would additionally need cuDNN flags.
SEED = 42
torch.manual_seed(SEED)

train_dir = Path("../data/split/train")
val_dir   = Path("../data/split/val")

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# ---- Stronger augmentations for better generalization ----
train_tfms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomRotation(20),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic preprocessing for validation: same geometry, no augmentation.
val_tfms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = datasets.ImageFolder(train_dir, transform=train_tfms)
val_dataset   = datasets.ImageFolder(val_dir,   transform=val_tfms)

# ImageFolder derives label indices from each directory's own sorted sub-folder
# names; if the splits' class folders differ, labels would silently misalign.
assert train_dataset.classes == val_dataset.classes, (
    f"Train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

batch_size = 32  # reduce to 16 if RAM is angry
pin_memory = device.type == "cuda"  # speeds host->GPU copies; no benefit on CPU
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  pin_memory=pin_memory)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

print("Train size:", len(train_dataset))
print("Val size  :", len(val_dataset))


Device: cpu
Train size: 1767
Val size  : 377


In [6]:

# ---- Model (full fine-tuning, nothing frozen) ----
# Derive the class count from the dataset so the head always matches the
# ImageFolder label space.
num_classes = len(train_dataset.classes)
model = SimpleCNN(num_classes=num_classes, embed_dim=256).to(device)

# All params trainable; low LR because the whole pretrained backbone is fine-tuned.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

criterion = nn.CrossEntropyLoss()

num_epochs = 8
best_val_acc = 0.0
best_epoch = None
best_state_dict = None


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in train mode and its parameters
    are updated batch by batch. Otherwise it is put in eval mode and run with
    gradients disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (avoids a ZeroDivisionError).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch loss is a true per-sample mean even with a short last batch.
            loss_sum += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, correct / total


for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)

    print(
        f"Epoch [{epoch}/{num_epochs}] "
        f"Train: Loss={train_loss:.4f} Acc={train_acc:.4f} | "
        f"Val: Loss={val_loss:.4f} Acc={val_acc:.4f}"
    )

    # Keep best weights IN MEMORY ONLY. `Tensor.cpu()` returns the same tensor
    # when it already lives on the CPU, so clone explicitly; otherwise the
    # snapshot would alias the live parameters and drift as training continues.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        best_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

# Restore the best snapshot so downstream cells (e.g. saving) use the
# best-validation weights rather than the last epoch's.
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)

print("\n==== DONE ====")
print("Best validation accuracy:", best_val_acc)
print("Best epoch:", best_epoch)


Epoch [1/8] Train: Loss=1.1660 Acc=0.5823 | Val: Loss=0.6147 Acc=0.7639
Epoch [2/8] Train: Loss=0.5919 Acc=0.8087 | Val: Loss=0.4132 Acc=0.8647
Epoch [3/8] Train: Loss=0.3965 Acc=0.8721 | Val: Loss=0.4806 Acc=0.8223
Epoch [4/8] Train: Loss=0.3088 Acc=0.8993 | Val: Loss=0.2937 Acc=0.8992
Epoch [5/8] Train: Loss=0.2410 Acc=0.9208 | Val: Loss=0.3348 Acc=0.8912
Epoch [6/8] Train: Loss=0.2211 Acc=0.9304 | Val: Loss=0.2757 Acc=0.9098
Epoch [7/8] Train: Loss=0.1541 Acc=0.9502 | Val: Loss=0.2797 Acc=0.9072
Epoch [8/8] Train: Loss=0.1176 Acc=0.9626 | Val: Loss=0.2705 Acc=0.8992

==== DONE ====
Best validation accuracy: 0.9098143236074271


In [8]:
from pathlib import Path
import json
import torch
import os

print("Notebook running from:", os.getcwd())

# Paths are relative to the working directory (the notebooks/ folder here),
# so "../models" goes one level up -> <project root>/models.
save_dir = Path("../models/classifier")
save_dir.mkdir(parents=True, exist_ok=True)

# 1) Save model weights. Prefer the best-validation snapshot kept by the
#    training cell; fall back to the current weights if none was recorded.
state_to_save = best_state_dict if best_state_dict is not None else model.state_dict()
model_path = save_dir / "simple_cnn.pth"
torch.save(state_to_save, model_path)
print(f"Saved model weights to: {model_path}")

# 2) Save class names in ImageFolder label order (label i <-> class_names[i]),
#    so inference code can map logits back to names.
class_names = train_dataset.classes
classes_path = save_dir / "classes.json"
with open(classes_path, "w", encoding="utf-8") as f:
    json.dump(class_names, f, indent=2)

print(f"Saved class names to: {classes_path}")
print("Classes:", class_names)


Notebook running from: C:\Users\sadek\OneDrive\Desktop\DSAI4101-project\notebooks
Saved model weights to: ..\models\classifier\simple_cnn.pth
Saved class names to: ..\models\classifier\classes.json
Classes: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
