# Imports

In [None]:
import json
import os

import copy
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision.transforms import ToTensor
from torchvision import transforms, models

from torchinfo import summary


# Checking device

In [None]:
# Report whether a GPU is usable and pick the compute device.
# The device queries are only issued when CUDA is available: get_device_name /
# current_device raise on CPU-only setups, which would make the CPU fallback
# below unreachable.
cuda_available = torch.cuda.is_available()
print("CUDA available?", cuda_available)
if cuda_available:
    print("Device name:", torch.cuda.get_device_name(0))
    print("Current device:", torch.cuda.current_device())

device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

print("Using:", device)

# Experiment setup

In [None]:
# Every hyper-parameter of the run lives in this one dict; it is also written to
# <experiment folder>/setup.txt so each run records the configuration it used.
setup = dict(
    experiment="resnet50-700-SGD-CELoss",
    num_classes=700,
    batch_size=128,
    num_workers=1,
    criterion=nn.CrossEntropyLoss(),
    lr=1e-3,
    weight_decay=1e-4,
    momentum=0.9,
    max_epochs=10,
)

folder = f"./experiments/{setup['experiment']}"
os.makedirs(folder, exist_ok=True)
file_path = os.path.join(folder, 'setup.txt')

# default=str serializes values JSON cannot encode (the criterion module) as their
# string form.
with open(file_path, 'w', encoding='utf-8') as f:
    json.dump(setup, f, indent=4, ensure_ascii=False, default=str)

print(f"Saved in: {file_path}")

# Per-run artifact directories, both nested under the experiment folder.
tensorboard_path = f"{folder}/tensorboard/"
models_path = f"{folder}/models/"

for artifact_dir in (tensorboard_path, models_path):
    os.makedirs(artifact_dir, exist_ok=True)

# TensorBoard functions

In [None]:
def plot_net_attributes(epoch, net, writer):
    """Log TensorBoard histograms for every ``nn.Linear`` layer in ``net``.

    For each linear layer (numbered 1, 2, ... in ``net.modules()`` order) this
    logs ``Bias/linear-<k>``, ``Weight/linear-<k>`` and ``Grad/linear-<k>``.

    Args:
        epoch: global step passed to ``writer.add_histogram``.
        net: module whose submodules are scanned.
        writer: TensorBoard ``SummaryWriter`` (anything exposing ``add_histogram``).

    A bias that does not exist (``bias=False``) or a weight gradient that is
    ``None`` (no backward pass yet, or a frozen layer) is skipped, because
    ``add_histogram`` cannot histogram ``None``.
    """
    linear_layers = (module for module in net.modules() if isinstance(module, nn.Linear))

    for layer_id, layer in enumerate(linear_layers, start=1):
        if layer.bias is not None:
            writer.add_histogram(f'Bias/linear-{layer_id}', layer.bias, epoch)
        writer.add_histogram(f'Weight/linear-{layer_id}', layer.weight, epoch)
        if layer.weight.grad is not None:
            writer.add_histogram(f'Grad/linear-{layer_id}', layer.weight.grad, epoch)

# Augmentations

In [None]:
# Per-channel ImageNet mean/std, used to normalize inputs for the
# ImageNet-pretrained backbone.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: fixed 224x224 resize plus a random horizontal flip.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Validation pipeline: same resize and normalization, no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

## View

In [None]:
# Preview pipeline: same resize/flip as train_transform but without Normalize,
# so the resulting tensors can be shown directly with imshow.
train_transform_view = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

img_dir = "../dataset/nabirds/versions/1/images/0295"

# Sort for a reproducible preview (os.listdir order is arbitrary) and also accept
# upper-case ".JPG" extensions.
img_files = sorted(f for f in os.listdir(img_dir) if f.lower().endswith(".jpg"))

N = 4
img_files = img_files[:N]

cols = 2
# Ceil-divide the images actually found into rows; keep at least one row so
# plt.subplots always receives a valid grid.
rows = max(1, -(-len(img_files) // cols))

# squeeze=False keeps `axes` 2-D even for a single row, so axes[r, c] always works.
fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows), squeeze=False)

for idx, img_fname in enumerate(img_files):
    # Context manager closes the file; convert() returns an independent loaded copy.
    with Image.open(os.path.join(img_dir, img_fname)) as img:
        pil = img.convert("RGB")

    # The transformed tensor is channel-first; imshow expects channel-last.
    out = train_transform_view(pil).permute(1, 2, 0).cpu().numpy()

    r, c = divmod(idx, cols)
    ax = axes[r, c]
    ax.imshow(out)
    ax.set_title(f"Image {idx+1}")
    ax.axis("off")

# Hide grid cells that received no image.
for ax in axes.flat[len(img_files):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

# NABirds

## Dataset

In [None]:
class NABirdsDataset(Dataset):
    """NABirds images and class labels for one split of ``train_test_split.txt``.

    Reads three whitespace-separated index files from ``root_dir``:
    ``images.txt`` (image_id -> path relative to ``<root_dir>/images``),
    ``image_class_labels.txt`` (image_id -> 1-based class id) and
    ``train_test_split.txt`` (image_id -> split flag). Lines that do not have
    exactly two tokens are ignored.
    """

    def __init__(self, root_dir, split='train', transform=None, remap_labels=True, label_map=None):
        """Index the samples of one split.

        Args:
            root_dir: dataset root containing the index files and ``images/``.
            split: ``'train'`` selects rows flagged ``1``; any other value selects
                rows flagged ``0``.
            transform: optional callable applied to each PIL image in ``__getitem__``.
            remap_labels: if True, compress the (0-based) class ids that occur to
                ``0..N-1`` so the classifier only needs N outputs.
            label_map: optional explicit ``{original_label: new_label}`` mapping used
                when ``remap_labels`` is True. When omitted, the mapping is built from
                the labels of ALL images listed in ``images.txt`` (every split), so
                train and val datasets constructed independently agree on what each
                class index means. An explicit map must cover every label of the split.
        """
        self.root = root_dir
        self.transform = transform

        # 1) image_id -> path relative to <root>/images
        self.id2path = dict(self._read_pairs("images.txt"))

        # 2) image_id -> original label (shifted to 0-based)
        raw_id2label = {
            img_id: int(cls) - 1 for img_id, cls in self._read_pairs("image_class_labels.txt")
        }

        # 3) raw (not yet remapped) samples of the requested split; images missing a
        #    path or a label are skipped instead of raising KeyError.
        flag_target = '1' if split == 'train' else '0'
        raw_samples = [
            (img_id, raw_id2label[img_id])
            for img_id, flag in self._read_pairs("train_test_split.txt")
            if flag == flag_target and img_id in self.id2path and img_id in raw_id2label
        ]

        # 4) optional label compression
        if remap_labels:
            if label_map is None:
                # Build the mapping from every labelled image rather than only this
                # split: a per-split mapping could give the same class different
                # indices in train and val.
                unique_labels = sorted(
                    {lbl for img_id, lbl in raw_id2label.items() if img_id in self.id2path}
                )
                label_map = {old: new for new, old in enumerate(unique_labels)}
            self.label_map = label_map
            self.samples = [(img_id, label_map[lbl]) for img_id, lbl in raw_samples]
            self.num_classes = max(label_map.values(), default=-1) + 1
        else:
            self.label_map = None
            self.samples = raw_samples
            # default=-1 keeps an empty split from crashing (num_classes becomes 0).
            self.num_classes = max((lbl for _, lbl in raw_samples), default=-1) + 1

    def _read_pairs(self, filename):
        """Yield ``(first, second)`` for each two-token line of ``<root>/<filename>``."""
        with open(os.path.join(self.root, filename), "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) == 2:
                    yield parts[0], parts[1]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB and passed through ``transform`` if set."""
        img_id, label = self.samples[idx]
        path = os.path.join(self.root, "images", self.id2path[img_id])
        # Context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

data_root = "../dataset/nabirds/versions/1"

# 'train' selects images flagged 1 in train_test_split.txt; 'val' selects the rest.
train_dataset = NABirdsDataset(
    data_root,
    split='train',
    transform=train_transform
)

val_dataset = NABirdsDataset(
    data_root,
    split='val',
    transform=val_transform
)

print(f"Classes efetivas: {train_dataset.num_classes}")

# The classifier head is sized by setup["num_classes"]; any label at or above that
# value would make the loss fail mid-training, so fail fast here instead.
dataset_classes = max(train_dataset.num_classes, val_dataset.num_classes)
if dataset_classes > setup["num_classes"]:
    raise ValueError(
        f"Dataset has {dataset_classes} classes but setup['num_classes'] is "
        f"{setup['num_classes']}; increase setup['num_classes']."
    )

## Dataloader

In [None]:
# Options shared by both loaders; only shuffling differs between the splits.
loader_kwargs = dict(
    batch_size=setup["batch_size"],
    num_workers=setup["num_workers"],
    pin_memory=True,
)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Architecture

In [None]:
# ImageNet-pretrained ResNet-50 that is fine-tuned only in its last residual
# stage (layer4) and in the new classification head.
weights = models.ResNet50_Weights.IMAGENET1K_V1
net = models.resnet50(weights=weights)

# Freeze the whole backbone, then unfreeze layer4.
net.requires_grad_(False)
net.layer4.requires_grad_(True)

# Swap in a head sized for this experiment; a freshly created nn.Linear has
# trainable parameters, so the head is updated as well.
net.fc = nn.Linear(net.fc.in_features, setup["num_classes"])

## View

In [None]:
# Show the replacement classification head (str() of a module is its repr()).
print(repr(net.fc))

In [None]:
# List the parameters left trainable (layer4 and the new head); only these are
# handed to the optimizer in train().
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name, param.requires_grad)

In [None]:
# Dummy input shape for torchinfo: a full batch of 3x224x224 images, matching the
# tensors produced by the transforms.
summary_input_size = (setup['batch_size'], 3, 224, 224)
summary(net, input_size=summary_input_size)

# Train

In [None]:
def _run_train_epoch(net, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return per-batch metrics.

    Returns:
        ``(losses, accuracies)`` as NumPy arrays with one entry per batch: the
        criterion value and the fraction of correctly classified samples.
    """
    net.train()

    losses, accuracies = [], []
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outs = net(inputs)
        loss = criterion(outs, labels)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        correct = (outs.argmax(dim=1) == labels).sum().item()
        accuracies.append(correct / labels.size(0))

    return np.asarray(losses), np.asarray(accuracies)


def train(net, train_dataloader, val_dataloader, device):
    """Fine-tune ``net`` with SGD and return the best-validation-accuracy copy.

    Hyper-parameters (lr, weight_decay, momentum, criterion, max_epochs) are read
    from the module-level ``setup`` dict. Scalars and histograms are written to
    ``tensorboard_path``; the best model (whole module, via ``torch.save``) is
    saved under ``models_path``. Only parameters with ``requires_grad=True`` are
    optimized.

    Returns:
        A deep copy of ``net`` from the epoch with the highest mean per-batch
        validation accuracy, or a copy of the current ``net`` if no epoch ran or
        none improved on the initial value.
    """
    net.to(device)

    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=setup['lr'],
        weight_decay=setup['weight_decay'],
        momentum=setup['momentum'])

    criterion = setup['criterion']
    criterion.to(device)

    writer = SummaryWriter(log_dir=tensorboard_path)
    try:
        # Trace the model graph with one real training batch.
        writer.add_graph(net, next(iter(train_dataloader))[0].to(device))

        best_model = None
        max_accuracy = -1.0

        for epoch in tqdm(range(setup['max_epochs'])):
            train_loss, train_accuracy = _run_train_epoch(
                net, train_dataloader, criterion, optimizer, device)

            # One histogram snapshot per epoch (weights plus the last batch's
            # gradients) instead of one per batch under the same step.
            plot_net_attributes(epoch, net, writer)

            val_loss, val_accuracy = validate(net, criterion, val_dataloader, device)

            writer.add_scalar('Loss/train', train_loss.mean(), epoch)
            writer.add_scalar('Loss/val', val_loss.mean(), epoch)
            writer.add_scalar('Accuracy/train', train_accuracy.mean(), epoch)
            writer.add_scalar('Accuracy/val', val_accuracy.mean(), epoch)

            if val_accuracy.mean() > max_accuracy:
                best_model = copy.deepcopy(net)
                max_accuracy = val_accuracy.mean()
                print(f'Saving the model with the best accuracy: {max_accuracy:3.4f}')

            # Losses are raw criterion values; accuracies are fractions shown as percentages.
            print(f'Epoch: {epoch + 1:3d} | Loss/train: {train_loss.mean():.4f} | '
                  f'Accuracy/train: {100 * train_accuracy.mean():.2f}% | '
                  f'Loss/val: {val_loss.mean():.4f} | '
                  f'Accuracy/val: {100 * val_accuracy.mean():.2f}% |')

        if best_model is None:
            # No epoch beat the initial -1.0 (max_epochs == 0, or NaN accuracy from
            # an empty validation loader): fall back to the current weights.
            best_model = copy.deepcopy(net)

        path = f'{models_path}{setup["experiment"]}-{max_accuracy:.2f}.pkl'
        torch.save(best_model, path)
        print(f'Best model saved in: {path}')
    finally:
        # Release the event file even if training is interrupted or raises.
        writer.flush()
        writer.close()

    return best_model

# Validate

In [None]:
def validate(net, criterion, val_dataloader, device):
    """Evaluate ``net`` on ``val_dataloader`` without tracking gradients.

    Side effects: ``net`` is switched to eval mode and moved to ``device``, and
    it is left that way.

    Returns:
        ``(val_loss, val_accuracy)``: NumPy arrays with one entry per batch,
        holding the criterion value and the fraction of correct predictions.
    """
    net.eval()
    net.to(device)

    batch_losses = []
    batch_accuracies = []

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = net(inputs)
            batch_losses.append(criterion(logits, targets).item())

            predictions = logits.max(dim=1).indices
            hits = (predictions == targets).sum().item()
            batch_accuracies.append(hits / targets.size(0))

    return np.asarray(batch_losses), np.asarray(batch_accuracies)

# Fit

In [None]:
# Run the full fine-tuning loop; `best_model` receives the best-validation copy.
best_model = train(net, train_dataloader, val_dataloader, device=device)