In [1]:
# import constants
import sys
import os
sys.path.append(os.path.abspath('..'))  # Add parent directory to path
from src.utils.constants import DATA_DIR, MNIST_DIR, SUPERVISED_DIR, WEAKLY_SUPERVISED_DIR
from data.dataset import SupervisedMNISTDataset, WeaklySupervisedMNISTDataset
from src.models.digit_classifier import SupervisedModel

import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# Supervised loaders, one per split. Only the training split is shuffled;
# validation and test keep the dataset's own ordering.
train_loader, val_loader, test_loader = (
    DataLoader(
        SupervisedMNISTDataset(split=split_name),
        batch_size=128,
        shuffle=(split_name == 'train'),
    )
    for split_name in ('train', 'val', 'test')
)

In [3]:
# Weakly supervised loaders, one per split. Only the training split is
# shuffled; validation and test keep the dataset's own ordering.
(
    weakly_supervised_train_loader,
    weakly_supervised_val_loader,
    weakly_supervised_test_loader,
) = (
    DataLoader(
        WeaklySupervisedMNISTDataset(split=split_name),
        batch_size=128,
        shuffle=(split_name == 'train'),
    )
    for split_name in ('train', 'val', 'test')
)

In [4]:
# Sanity check: pull a single batch from the weakly supervised training
# loader and show the tensor shapes. Nothing is printed if the loader is empty.
first_batch = next(iter(weakly_supervised_train_loader), None)
if first_batch is not None:
    inputs, labels = first_batch
    print(inputs.shape, labels.shape)

torch.Size([128, 2, 1, 28, 28]) torch.Size([128, 1])


In [None]:
# Train a SupervisedModel
# Seed torch's global RNG before the model is built so that repeated
# "Restart & Run All" executions are comparable.
SEED = 42
torch.manual_seed(SEED)

model = SupervisedModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training parameters
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


def run_epoch(loader, training):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches.
        training: if True, the model is put in train mode and the optimizer
            is stepped on every batch; if False, the model is put in eval
            mode and gradients are disabled.

    Returns:
        A tuple ``(mean_loss, accuracy_pct)`` where ``mean_loss`` is the
        per-sample average loss over the whole pass and ``accuracy_pct`` is
        the percentage of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(training)  # model.train(False) is equivalent to model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    # Progress bar only for the training pass, as before.
    batches = tqdm(loader) if training else loader

    with torch.set_grad_enabled(training):
        for inputs, labels in batches:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch *mean*; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('run_epoch received a loader that yielded no samples')

    return loss_sum / total, 100. * correct / total


# Training loop
train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(num_epochs):
    # Training phase
    train_loss, train_acc = run_epoch(train_loader, training=True)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation phase
    val_loss, val_acc = run_epoch(val_loader, training=False)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print('-' * 50)

# Plot training history
# x-axis is 1-based so it lines up with the "Epoch k/N" lines printed above.
epochs = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Val Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.set_title('Training and Validation Loss')

ax_acc.plot(epochs, train_accs, label='Train Acc')
ax_acc.plot(epochs, val_accs, label='Val Acc')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.set_title('Training and Validation Accuracy')

fig.tight_layout()
plt.show()