In [None]:
# Setup: report the visible GPU, then mount Google Drive
import torch

# Resolve the label first so the print statement stays simple;
# the literal string 'None' is shown when CUDA is not available.
if torch.cuda.is_available():
    gpu_label = torch.cuda.get_device_name(0)
else:
    gpu_label = 'None'
print(f"GPU: {gpu_label}")

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repository
!git clone https://github.com/emanusilva2003/Assignment_2_ComputerVision

import sys
repo_path = '/content/Assignment_2_ComputerVision'
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)
print(f'Repo path: {repo_path}')

In [None]:
# Imports
import os
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from SUIM.Pytorch.models.suim_net import SUIM_Net
from SUIM.Pytorch.utils.data_utils import get_suim_dataloader

In [None]:
# Configuration
BASE = 'RSB'  # passed as `base` to SUIM_Net: 'VGG' or 'RSB'
BATCH_SIZE, NUM_EPOCHS = 8, 50

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Paths (adjust the Drive path so it points at your dataset)
TRAIN_DIR = '/content/drive/MyDrive/VCOM/SUIM/train_val'
CKPT_DIR = '/content/drive/MyDrive/VCOM/checkpoints'
os.makedirs(CKPT_DIR, exist_ok=True)  # no error if the folder already exists

print(f"{BASE} | Batch: {BATCH_SIZE} | Epochs: {NUM_EPOCHS} | Device: {DEVICE}")

In [None]:
# Data and Model
# Augmentation settings, forwarded unchanged to the project's dataloader factory.
augmentation_settings = {
    'rotation_range': 0.2,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
}

train_loader = get_suim_dataloader(
    train_dir=TRAIN_DIR,
    batch_size=BATCH_SIZE,
    target_size=(320, 256),
    augmentation=True,
    augmentation_params=augmentation_settings,
    num_workers=2,
    shuffle=True,
)

# Build the network, then move it to the configured device.
model = SUIM_Net(base=BASE, n_classes=5, pretrained=True)
model = model.to(DEVICE)

# count_parameters() returns a pair; only the first value is reported below.
total, trainable = model.count_parameters()
num_samples = len(train_loader.dataset)
print(f"Samples: {num_samples} | Params: {total:,}")

In [None]:
# Training setup
# nn.BCELoss requires predictions in [0, 1]; this assumes SUIM_Net already
# applies a sigmoid to its outputs -- TODO confirm against the model code.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Halve the learning rate once the monitored loss has stopped improving for
# more than 5 consecutive epochs.
# NOTE: `verbose=True` is intentionally not passed. The keyword is deprecated
# since PyTorch 2.2 and is no longer accepted by recent releases (TypeError).
# The training loop already prints the current LR after every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [None]:
# Training loop
history = {'loss': [], 'lr': []}
best_loss = float('inf')


def _write_checkpoint(filename, epoch_idx, loss_value):
    """Save the epoch index, model/optimizer state and loss under CKPT_DIR."""
    snapshot = {
        'epoch': epoch_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }
    torch.save(snapshot, os.path.join(CKPT_DIR, filename))


for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for images, masks in progress:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        # Standard optimisation step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running total and the bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

    # Mean of the per-batch losses for this epoch.
    avg_loss = running_loss / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']  # LR used during this epoch
    history['loss'].append(avg_loss)
    history['lr'].append(current_lr)

    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
    # The plateau scheduler monitors the epoch's mean training loss.
    scheduler.step(avg_loss)

    # Save best
    if avg_loss < best_loss:
        best_loss = avg_loss
        _write_checkpoint(f"suimnet_{BASE.lower()}_best.pth", epoch, best_loss)
        print(f"  ✓ Best: {best_loss:.4f}")

    # Checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        _write_checkpoint(f"suimnet_{BASE.lower()}_epoch_{epoch+1}.pth", epoch, avg_loss)
        print(f"  ✓ Checkpoint saved")

print(f"\nCompleted! Best loss: {best_loss:.4f}")

In [None]:
# Plot results
import matplotlib.pyplot as plt

# One entry per panel: (history key, line format, y-axis label, title, y-scale).
# A y-scale of None leaves the axis on matplotlib's default scale.
panel_specs = [
    ('loss', 'b-', 'Loss', 'Training Loss', None),
    ('lr', 'r-', 'Learning Rate', 'LR Schedule', 'log'),
]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

for axis, (key, line_fmt, y_label, panel_title, y_scale) in zip((ax1, ax2), panel_specs):
    axis.plot(history[key], line_fmt, linewidth=2)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    if y_scale is not None:
        axis.set_yscale(y_scale)
    axis.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# One-line summary of the recorded training losses.
print(f"Initial: {history['loss'][0]:.4f} | Final: {history['loss'][-1]:.4f} | Best: {best_loss:.4f}")