In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from dataset import PosterDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# Pick the compute device once for the whole notebook: Apple's MPS backend when the
# installed torch build reports it as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [None]:
# Resize to the 224x224 input that BinaryCNN's classifier is sized for, convert to a
# tensor, and normalize with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = PosterDataset(root_dir="images", transform=transform)
print(f"Number of images: {len(dataset)}")

BATCH_SIZE = 4
NUM_WORKERS = 4
TRAIN_FRACTION, VAL_FRACTION = 0.8, 0.1  # the remainder becomes the test split
SPLIT_SEED = 42  # fixed so every re-run yields the same train/val/test partition

# Loader over the full, unsplit dataset. The train/val/test loaders below do not use it;
# it is kept only so any code referring to `dataloader` keeps working.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

dataset_size = len(dataset)
train_size = int(TRAIN_FRACTION * dataset_size)
val_size = int(VAL_FRACTION * dataset_size)
test_size = dataset_size - train_size - val_size
# An empty train or val split would make the per-sample averages in training divide by zero.
assert train_size > 0 and val_size > 0, (
    f"dataset of {dataset_size} images is too small for a train/val split"
)
# A dedicated seeded generator makes the split reproducible without touching global RNG state.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Number of samples: 10000
Number of batches: 2500


In [None]:
class BinaryCNN(nn.Module):
    """Small three-stage CNN producing a single logit per image (binary output).

    Each stage is a 3x3 conv (stride 1, padding 1, so spatial size is preserved),
    a ReLU, and a 2x2 max-pool that halves height and width (flooring odd sizes).
    The first linear layer is sized from ``input_size``, so ``forward`` expects
    square RGB inputs of exactly that side length.

    Args:
        input_size: side length in pixels of the square inputs. Defaults to 224,
            which matches the ``Resize((224, 224))`` preprocessing in this notebook.

    Raises:
        ValueError: if ``input_size`` is too small to survive three 2x2 poolings.

    ``forward`` returns raw logits of shape ``(batch, 1)``. Use them with
    ``nn.BCEWithLogitsLoss``, or apply ``torch.sigmoid`` to get a probability.
    """

    def __init__(self, input_size: int = 224):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be >= 8 to survive three 2x2 poolings, got {input_size}"
            )
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # 3 input channels: RGB
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),  # halves height and width
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        # Three successive floor-halvings equal one floor division by 8 (224 -> 28).
        feature_side = input_size // 8
        # Layer order (and thus state_dict keys) is unchanged, so saved checkpoints still load.
        self.classifier = nn.Sequential(
            nn.Linear(128 * feature_side * feature_side, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),  # single output logit
        )

    def forward(self, x):
        """Map a ``(batch, 3, input_size, input_size)`` tensor to ``(batch, 1)`` logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, 128 * side * side); also safe for non-contiguous input
        return self.classifier(x)

In [None]:
# `clear_output` was previously used without being imported, which raised NameError at
# the end of the first epoch. The local import keeps this cell self-contained.
try:
    from IPython.display import clear_output
except ImportError:  # not running under IPython/Jupyter: there is no output area to clear
    def clear_output(wait=False):
        """No-op stand-in for IPython.display.clear_output outside a notebook."""


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the pass trains: forward, backward and an optimizer step
    per batch. Otherwise it only evaluates, with gradients disabled. Loss and accuracy are
    averaged per sample rather than per batch, so a short final batch is weighted
    correctly. A non-empty ``desc`` shows a tqdm progress bar with the running batch loss.

    Raises:
        ValueError: if ``loader`` yields no samples (the averages would divide by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc, leave=False) if desc else loader
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            # BCEWithLogitsLoss needs float targets with the same (batch, 1) shape as the logits.
            labels = labels.float().unsqueeze(1)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = torch.sigmoid(outputs) >= 0.5
            total_correct += (preds == labels.byte()).sum().item()
            total_seen += batch_size
            if desc:
                batches.set_postfix(loss=loss.item())
    if total_seen == 0:
        raise ValueError(f"{desc or 'loader'}: no samples to average over")
    return total_loss / total_seen, total_correct / total_seen


def _plot_history(loss_train, loss_val, acc_train, acc_val):
    """Show loss and accuracy curves side by side, then close the figure."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
    ax_loss.plot(loss_train, label="Train Loss")
    ax_loss.plot(loss_val, label="Val Loss")
    ax_loss.set_title("Loss over Epochs")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()
    ax_acc.plot(acc_train, label="Train Accuracy")
    ax_acc.plot(acc_val, label="Val Accuracy")
    ax_acc.set_title("Accuracy over Epochs")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.legend()
    plt.show()
    # One figure is drawn per epoch; close it so figures don't pile up over a long run.
    plt.close(fig)


SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffle order

model = BinaryCNN().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training settings
epochs = 200
# Start below any reachable accuracy so the first epoch always writes a checkpoint,
# even if its validation accuracy is 0.
best_val_acc = float("-inf")
best_epoch = 0
loss_history_train, loss_history_val = [], []
acc_history_train, acc_history_val = [], []

# Create models/run_<n> using the first unused n, so earlier runs are never overwritten.
save_dir = "models"
run_num = 1
run_folder = os.path.join(save_dir, f"run_{run_num}")
while os.path.exists(run_folder):
    run_num += 1
    run_folder = os.path.join(save_dir, f"run_{run_num}")
os.makedirs(run_folder, exist_ok=True)

for epoch in range(epochs):
    avg_train_loss, train_acc = _run_epoch(
        model, train_loader, criterion, device,
        optimizer=optimizer, desc=f"Train Epoch {epoch+1}/{epochs}",
    )
    avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

    # Keep only the weights from the best validation epoch (strict '>' keeps the earliest on ties).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), os.path.join(run_folder, "best_model.pth"))

    loss_history_train.append(avg_train_loss)
    loss_history_val.append(avg_val_loss)
    acc_history_train.append(train_acc)
    acc_history_val.append(val_acc)

    # Replace the previous epoch's plot and printout instead of stacking them.
    clear_output(wait=True)
    _plot_history(loss_history_train, loss_history_val, acc_history_train, acc_history_val)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val   Loss: {avg_val_loss:.4f}, Val   Acc: {val_acc:.4f}")
    print(f"  Best Val Acc: {best_val_acc:.4f} at Epoch {best_epoch+1}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

                                                                           

Epoch 1/5 - Loss: 0.5329 - Accuracy: 0.7365


                                                                            

Epoch 2/5 - Loss: 0.4258 - Accuracy: 0.8092


                                                                            

Epoch 3/5 - Loss: 0.3469 - Accuracy: 0.8479


                                                                             

Epoch 4/5 - Loss: 0.2461 - Accuracy: 0.8991


                                                                             

Epoch 5/5 - Loss: 0.1343 - Accuracy: 0.9465


