In [1]:
import os
import random
import time
# Torch
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data.sampler import SubsetRandomSampler

# Torchvison
import torchvision.transforms as T
import torchvision.models as models
# from torchvision.datasets import CIFAR100, CIFAR10

# Utils
# import visdom
from tqdm import tqdm

# Custom
from config import *
from ace_data import train_data, test_data

import timm
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

import datetime

# ResNet-50 with a 4-way classification head, randomly initialised.
# `pretrained=False` is timm's switch for "no pretrained weights"; torchvision's
# `weights=` keyword is not a timm argument (timm discards None-valued kwargs, so
# passing `weights=None` had no effect).
# NOTE(review): 4 classes presumably matches the labels produced by ace_data — confirm.
# The model is moved to `device` here, before any optimizer is built, because
# train_epoch/test send every batch to `device`; leaving the model on the CPU would
# cause a device mismatch when `device` is a GPU. `device` is presumably defined by
# `from config import *` (it is used but never defined in this notebook).
model = timm.create_model(model_name='resnet50', pretrained=False, num_classes=4).to(device)


In [2]:
# Training loader. shuffle=True reshuffles the samples every epoch; DataLoader's
# default (shuffle=False) would present identical batches in identical order each
# epoch. drop_last=True discards a trailing partial batch.
train_loader = DataLoader(train_data, batch_size=BATCH, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)
# Evaluation loader: one sample per batch, fixed order.
test_loader = DataLoader(test_data, num_workers=4, batch_size=1)

# NOTE(review): this `models` dict shadows the `torchvision.models` module imported
# in the first cell; that module is not used afterwards.
models = {'backbone': model}
# test() indexes dataloaders[mode]; only 'train' and 'test' loaders exist here.
dataloaders = {'train': train_loader, 'test': test_loader}

In [3]:
# Data loaders, built here as well so this cell is self-contained on re-run.
# shuffle=True reshuffles the training samples every epoch (DataLoader's default,
# shuffle=False, repeats the same batch order each epoch).
train_loader = DataLoader(train_data, batch_size=BATCH, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_data, num_workers=4, batch_size=1)

models = {'backbone': model}
dataloaders = {'train': train_loader, 'test': test_loader}

# reduction='none' returns one loss per sample; train_epoch averages them itself.
criterion = nn.CrossEntropyLoss(reduction='none')
optim_backbone = optim.SGD(models['backbone'].parameters(), lr=LR)
# MultiStepLR multiplies the learning rate by its default gamma (0.1) each time the
# step count reaches an entry of MILESTONES; train() steps it once per epoch.
sched_backbone = lr_scheduler.MultiStepLR(optim_backbone, milestones=MILESTONES)

optimizers = {'backbone': optim_backbone}
schedulers = {'backbone': sched_backbone}
        

# Train and Test 

In [8]:
def train_epoch(models, criterion, optimizers, dataloaders, epoch, epoch_loss, plot_data=None):
    """Run one optimisation pass over ``dataloaders['train']``.

    For every batch, inputs and labels are moved to ``device`` and scored by
    ``models['backbone']``. The batch mean of the per-sample losses from ``criterion``
    is then backpropagated, followed by one ``optimizers['backbone']`` step.

    ``epoch``, ``epoch_loss`` and ``plot_data`` are accepted for call compatibility
    but are not used. Returns ``None``.
    """
    net = models['backbone']
    net.train()
    batches = dataloaders['train']

    for batch in tqdm(batches, leave=False, total=len(batches)):
        x, y = batch[0].to(device), batch[1].to(device)

        optimizers['backbone'].zero_grad()

        # criterion is expected to return one loss per sample (the notebook builds it
        # with reduction='none'); dividing the sum by the batch size gives the mean.
        per_sample = criterion(net(x), y)
        batch_loss = torch.sum(per_sample) / per_sample.size(0)

        batch_loss.backward()
        optimizers['backbone'].step()
    
def test(models, dataloaders, mode='val'):
    """Evaluate ``models['backbone']`` on ``dataloaders[mode]`` and return metrics.

    Args:
        models: dict holding the network under ``'backbone'``. The network's single
            return value is treated as class scores, the same way ``train_epoch``
            uses it.
        dataloaders: dict of DataLoaders; ``dataloaders[mode]`` must yield
            ``(inputs, labels)`` pairs.
        mode: ``'val'`` or ``'test'``. The notebook only builds a ``'test'`` loader,
            so the default ``'val'`` raises ``KeyError`` unless a ``'val'`` entry is
            added to ``dataloaders``.

    Returns:
        dict with keys ``'accuracy'``, ``'f1_score'``, ``'precision'`` and
        ``'recall'``. The last three use sklearn's ``average='weighted'``
        (per-class scores weighted by class support).

    Raises:
        AssertionError: if ``mode`` is neither ``'val'`` nor ``'test'``.
    """
    assert mode == 'val' or mode == 'test'
    backbone = models['backbone']
    backbone.eval()  # inference behaviour for BatchNorm/Dropout layers

    preds, labels = [], []
    with torch.no_grad():
        for inputs, label in dataloaders[mode]:
            inputs = inputs.to(device)
            label = label.to(device)
            # The backbone returns a single logits tensor. Unpacking it as
            # `scores, _ = ...` would iterate over the batch dimension and fail.
            scores = backbone(inputs)
            labels.extend(label.tolist())
            preds.extend(scores.argmax(dim=1).tolist())

    # zero_division=0 yields the same numbers as sklearn's default ('warn') for classes
    # with no predicted/true samples, without emitting UndefinedMetricWarning.
    metrics = {
        'accuracy': accuracy_score(y_true=labels, y_pred=preds),
        'f1_score': f1_score(y_true=labels, y_pred=preds,
                             average='weighted', zero_division=0),
        'precision': precision_score(y_true=labels, y_pred=preds,
                                     average='weighted', zero_division=0),
        'recall': recall_score(y_true=labels, y_pred=preds,
                               average='weighted', zero_division=0),
    }
    return metrics

#
def train(models, criterion, optimizers, schedulers, dataloaders, num_epochs, epoch_loss):
    """Train ``models['backbone']`` for ``num_epochs`` epochs.

    Each epoch runs ``train_epoch`` once and then advances ``schedulers['backbone']``
    by one step, so scheduler milestones are counted in epochs. A start and a finish
    message are printed around the loop.
    """
    print('>> Train a Model.')

    for ep in tqdm(range(num_epochs)):
        train_epoch(models, criterion, optimizers, dataloaders, ep, epoch_loss)
        # One scheduler step per completed epoch.
        schedulers['backbone'].step()

    print('>> Finished.')


In [9]:
# Train for EPOCH epochs (EPOCH/EPOCHL come from config; EPOCHL is forwarded as
# `epoch_loss`, which train_epoch does not use), then evaluate on the test loader.
train(models, criterion, optimizers, schedulers, dataloaders, EPOCH, EPOCHL)
metrics = test(models, dataloaders, mode='test')

# Report each metric to four decimal places.
for display_name, metric_key in (('accuracy', 'accuracy'),
                                 ('f1 score', 'f1_score'),
                                 ('precision', 'precision'),
                                 ('recall', 'recall')):
    print(f"{display_name} : {metrics[metric_key]:.4f}")

>> Train a Model.


  0%|          | 0/200 [00:00<?, ?it/s]

In [2]:
import numpy as np

# NOTE(review): leftover scratch cell (its execution count restarted at In [2]); it is
# unrelated to the training pipeline above and could be removed.
b = np.zeros(shape=(10,), dtype=np.float64)
print(b)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
