In [2]:
import warnings
warnings.filterwarnings('ignore')
from glob import glob 
import os 
import random 

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from PIL import Image 
import cv2


import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms 
import torchvision
import timm 
import wandb

from src.data.augmentation import *
from src.data.factory import create_dataset,create_dataloader
from src.options import Options
from src.models import Model 
from src.loss_function import LossFunction
from src.train import AverageMeter,torch_seed,train_epoch,valid_epoch
from src.callback import Callbacks

In [3]:
def run(cfg):
    """Build every object needed to train a model from a config mapping.

    Seeds the RNGs, creates the train/test datasets and their loaders,
    instantiates the model on ``cfg['device']``, and builds the loss, an Adam
    optimizer and a cosine-annealing LR scheduler. Training is NOT started
    here; pass the returned objects (plus ``cfg``) to ``fit``.

    Args:
        cfg: configuration mapping. Keys read directly here: 'seed',
            'Batchsize', 'modeltype', 'device', 'lr', 'beta1', 'nepochs'
            (``create_dataset`` may read further keys).

    Returns:
        Tuple ``(model, train_loader, test_loader, criterion, optimizer,
        scheduler)``.
    """
    torch_seed(cfg['seed'])

    # Build train/test datasets and their loaders.
    trainset, testset = create_dataset(cfg)
    train_loader = create_dataloader(
        dataset=trainset,
        batch_size=cfg['Batchsize'],
        shuffle=True)
    # The evaluation loss does not depend on sample order, so keep the test
    # loader deterministic; shuffling it only makes validation runs
    # non-reproducible.
    test_loader = create_dataloader(
        dataset=testset,
        batch_size=cfg['Batchsize'],
        shuffle=False)

    # Build the model, criterion, optimizer and LR schedule.
    model = Model(cfg['modeltype']).to(cfg['device'])
    criterion = LossFunction()
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg['lr'],
        betas=(cfg['beta1'], 0.999))
    # T_max = number of epochs: fit() calls scheduler.step() once per epoch,
    # so the LR follows one cosine half-period over the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg['nepochs'])

    print('All loaded, Training start')
    return model, train_loader, test_loader, criterion, optimizer, scheduler
    
def fit(model,train_loader,test_loader,criterion,optimizer,scheduler,cfg):
    """Train ``model`` for ``cfg['nepochs']`` epochs, validating after each.

    Each epoch runs ``train_epoch`` and then ``valid_epoch``, advances the LR
    scheduler once, and passes a log dict to ``Callbacks.epoch`` (the
    per-epoch checkpoint hook). After the final epoch the callback is invoked
    once more with the tag ``'last'``.

    Args:
        model, train_loader, test_loader, criterion, optimizer, scheduler:
            objects as returned by ``run``.
        cfg: configuration mapping; 'nepochs' is read here and the whole
            mapping is forwarded to ``Callbacks``, ``train_epoch`` and
            ``valid_epoch``.

    Returns:
        dict with keys 'train' and 'valid', each a list holding the
        per-epoch values returned by ``train_epoch`` / ``valid_epoch``.
    """
    callbacks = Callbacks(cfg)

    total_loss = {'train': [], 'valid': []}
    log = None

    for epoch in range(cfg['nepochs']):
        train_loss = train_epoch(model,train_loader,criterion,optimizer,cfg)
        valid_loss = valid_epoch(model,test_loader,criterion,cfg)

        total_loss['train'].append(train_loss)
        total_loss['valid'].append(valid_loss)

        # Capture the LR used for this epoch BEFORE stepping the scheduler;
        # after step() param_groups already hold the next epoch's LR.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        # NOTE(review): 'learing_rate' is misspelled but kept as-is because
        # Callbacks / downstream loggers may key on it.
        log = {'Epoch' : epoch,
               'train_loss' : train_loss,
               'valid_loss' : valid_loss,
               'learing_rate' : epoch_lr}

        # check point
        callbacks.epoch(model,log)

    # With nepochs <= 0 no epoch ran and there is no log to report; the
    # unguarded call used to fail on an unbound `log`.
    if log is not None:
        callbacks.epoch(model,log,'last')

    return total_loss

In [1]:
import yaml 
# Load the experiment configuration (seed, Batchsize, lr, beta1, nepochs,
# modeltype, device, transform, ...) into the `cfg` mapping consumed by
# run() and fit().
with open('./configs/default.yaml', mode='r') as f:
    cfg = yaml.load(stream=f, Loader=yaml.FullLoader)
    

In [2]:
# Display the config's 'transform' entry (the data-augmentation setting).
cfg["transform"]

'no_augmentation'