In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets_classification as dataset
import resnet

In [2]:
# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Experiment configuration; these values are forwarded to dataset.get_train_data.
train_size = 100

# Augmentation switches for the training set.
use_augmentation = True
augmentation_type, augmentation_size = "basic", 50

# Checkpoint file name; encodes the configuration above so runs do not overwrite each other.
model_path = f"model_{augmentation_type}_aug_{augmentation_size}_train_{train_size}.pt"

In [None]:
# Root directory of the image data, relative to the notebook's working directory.
images_dir = "../../data/vinbig_data/"

# Build the three dataset splits. Only the training split takes the size and
# augmentation settings from the configuration cell.
train_set = dataset.get_train_data(images_dir,
                                   train_size = train_size,
                                   use_augmentation = use_augmentation,
                                   augmentation_type = augmentation_type,
                                   augmentation_size = augmentation_size)
val_set = dataset.get_val_data(images_dir)
test_set = dataset.get_test_data(images_dir)

# Set up dataloaders. Only the training loader shuffles: evaluation loaders
# iterate in a fixed order so validation/test numbers are reproducible
# between runs (batch composition no longer changes from call to call).
batch_size = 8
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_set, batch_size = batch_size, shuffle = False)
test_dataloader = DataLoader(test_set, batch_size = 1, shuffle = False)

In [None]:
# Instantiate the network (image_channels=1, num_classes=15) directly on the target device.
model = resnet.ResNet_18(image_channels=1, num_classes=15).to(device)

# Number of training epochs and the validation interval (in epochs) used by train().
epochs, val_check = 10, 10

# Loss, optimiser and plateau-based learning-rate scheduler.
# NOTE(review): nn.BCELoss expects inputs in [0, 1]; presumably the network
# ends in a sigmoid -- confirm against resnet.ResNet_18.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)

In [None]:
# train and eval the model
def save_ckpt(model, path):
    """Save the model's state dict to ``classification/<path>``.

    Args:
        model: module whose ``state_dict()`` is serialised.
        path: file name (or relative path) inside the ``classification`` directory.

    The destination directory is created on demand so the first checkpoint of
    a fresh checkout does not fail with ``FileNotFoundError``.
    """
    ckpt_path = os.path.join("classification", path)
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    torch.save(model.state_dict(), ckpt_path)

def get_metric(pred, target, thresh=0.5):
    """Sample-averaged F1 score for multi-label predictions.

    Scores above ``thresh`` are treated as positive labels. For each sample
    the F1 is ``2 * TP / (|predicted| + |true|)``; the per-sample values are
    then averaged. This mirrors scikit-learn's ``f1_score(average='samples')``
    and is computed with NumPy because ``f1_score`` is never imported in this
    notebook.

    Args:
        pred: 2-D array or CPU tensor of scores, one row per sample.
        target: 2-D binary array or CPU tensor with the same shape as ``pred``.
        thresh: decision threshold applied to ``pred``.

    Returns:
        Mean per-sample F1 as a NumPy float. A sample with no predicted and no
        true labels contributes 0.0.

    Raises:
        ValueError: if the inputs are not 2-D or their shapes differ.
    """
    # Threshold first (as the original did) so tensors are compared before conversion.
    pred_labels = np.asarray(pred > thresh, dtype=bool)
    true_labels = np.asarray(target) != 0
    if pred_labels.ndim != 2 or pred_labels.shape != true_labels.shape:
        raise ValueError(
            "pred and target must be 2-D with identical shapes, got "
            f"{pred_labels.shape} and {true_labels.shape}"
        )

    true_positives = np.logical_and(pred_labels, true_labels).sum(axis=1)
    label_counts = pred_labels.sum(axis=1) + true_labels.sum(axis=1)

    # Leave the score at 0.0 wherever the denominator is zero (no labels at all).
    per_sample_f1 = np.zeros(label_counts.shape, dtype=float)
    np.divide(2.0 * true_positives, label_counts, out=per_sample_f1, where=label_counts > 0)
    return per_sample_f1.mean()

def test(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` and return its mean F1 score.

    Each batch is scored with ``get_metric`` (a per-sample average), and the
    batch scores are combined weighted by batch size. The result is therefore
    the average over all samples even when the last batch is smaller than the
    others. With equal-sized batches this equals the plain mean of the batch
    scores.

    Args:
        model: network to evaluate; it is moved to the module-level ``device``
            and left in eval mode.
        test_dataloader: iterable yielding ``(data, targets)`` batches.

    Returns:
        Sample-weighted mean of the per-batch scores, or NaN if the loader
        yields no samples.
    """
    model = model.to(device)
    model.eval()

    weighted_score_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for data, targets in test_dataloader:
            data = data.to(device)
            targets = targets.type(torch.int8).to(device)

            pred = model(data)
            batch_samples = len(targets)
            weighted_score_sum += get_metric(pred.cpu(), targets.cpu()) * batch_samples
            sample_count += batch_samples

    if sample_count == 0:
        return float("nan")
    return weighted_score_sum / sample_count

def validate(model, val_dataloader):
    """Return the mean per-batch loss of ``model`` over ``val_dataloader``.

    The model is switched to eval mode and is left in eval mode on return, so
    the caller must call ``model.train()`` before resuming training. Forward
    passes run under ``torch.no_grad()`` so no autograd graph is built, which
    keeps validation memory use low.

    Args:
        model: network to evaluate; expected to already live on ``device``.
        val_dataloader: iterable yielding ``(data, targets)`` batches.

    Returns:
        Mean of the per-batch values of the module-level ``criterion``.
    """
    val_losses = []

    model.eval()
    with torch.no_grad():
        for data, targets in val_dataloader:
            data = data.to(device)
            targets = targets.to(device)

            pred = model(data)
            loss = criterion(pred.type(torch.float), targets.type(torch.float))
            val_losses.append(loss.item())

    return np.mean(val_losses)

def train(model, train_dataloader, val_dataloader, save_path):
    """Train ``model`` and checkpoint the weights with the best validation loss.

    Uses the module-level ``epochs``, ``val_check``, ``criterion``,
    ``optimizer``, ``lr_scheduler`` and ``device``.

    Args:
        model: network to optimise; expected to already live on ``device``.
        train_dataloader: iterable yielding ``(imgs, targets)`` training batches.
        val_dataloader: loader passed to ``validate`` at each validation check.
        save_path: checkpoint file name handed to ``save_ckpt``.

    Validation runs after every ``val_check``-th completed epoch and always
    after the final epoch, so the weights produced by the full run are
    considered for checkpointing. The previous ``epoch % val_check == 0``
    condition fired only after the first epoch when ``epochs == val_check``.
    """
    print(f"Training for model: {save_path}")

    best_loss = np.inf
    for epoch in tqdm(range(epochs)):
        # validate() leaves the model in eval mode, so re-enter train mode every epoch.
        model.train()

        batch_losses = []
        for (imgs, targets) in train_dataloader:
            imgs, targets = imgs.to(device), targets.to(device)

            optimizer.zero_grad()

            pred = model(imgs)
            loss = criterion(pred.type(torch.float), targets.type(torch.float))

            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        train_loss = np.mean(batch_losses)
        print(f"Epoch: {epoch}")
        print(f"[TRAIN] Loss: {train_loss:.3f}")

        is_last_epoch = epoch == epochs - 1
        if (epoch + 1) % val_check == 0 or is_last_epoch:
            val_loss = validate(model, val_dataloader)

            print(f"[VALIDATION] Loss: {val_loss:.3f}\n")

            # ReduceLROnPlateau only acts when it is fed the monitored metric.
            lr_scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                save_ckpt(model, save_path)

## Training

In [None]:
# Start the training run; the best checkpoint is written under the name in model_path.
train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    save_path=model_path,
)

In [3]:
# Hand cached, unused GPU memory back to the driver before the evaluation cells run.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Testing

In [None]:
# Score every checkpoint stored in the "classification" directory on the test set.
model_names = sorted([i for i in os.listdir('classification') if i.endswith('.pt')])
model = resnet.ResNet_18(image_channels=1, num_classes=15)

for model_weights_path in model_names:
    # os.listdir returns bare file names, so re-attach the directory before loading.
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    weights = torch.load(os.path.join('classification', model_weights_path),
                         map_location=device)
    model.load_state_dict(weights)
    test_score = test(model, test_dataloader)

    # Report the checkpoint name rather than the full module repr.
    print(model_weights_path, test_score)