In [86]:
import torch
import torch.optim as optim
from torchsummary import summary
from tqdm import tqdm
import torch.nn as nn
import torchvision.transforms as transforms


In [87]:
from torch.utils.data import Dataset,DataLoader

In [88]:
import model_skel

In [89]:
model=model_skel.model_skel(16,32,16)

In [96]:
class dataset_skel(Dataset):
    import random
    def __init__(self,size):
        self.rand_data=torch.rand(size,requires_grad=True)
    def __len__(self):
        return len(self.rand_data)
    def __getitem__(self,idx):
        return self.rand_data[idx],self.rand_data[idx]

In [97]:
train_ds=dataset_skel((32,16))
valid_ds=dataset_skel((16,16))

In [98]:
train_dl=DataLoader(train_ds,batch_size=8)
valid_dl=DataLoader(valid_ds,batch_size=8)

In [99]:
config_=dict()
config_["lr"]=0.3
config_["crit"]=nn.L1Loss()
config_["optim"]=optim.Adam(model.parameters(),lr=config_["lr"])
config_["save_dir"]="models/"
config_["epochs"]=10
config_["train_log"]=[]
config_["valid_log"]=[]
config_["log_interval"]=4
config_["device"]="cpu"

In [100]:
def train(config,train_dl,valid_dl,model):
    '''
    config.optimizer: Training optimizer eg) Adam,SGD....
    config.criterion: Loss eg) CrossEntropy, 
    config.lr: learning_rate
    train_dl,valid_dl: Dataloader
    model: target model
    save_path
    '''
    
    for i in range(config["epochs"]):
        print(f'Epochs : {i+1}')
        model.train()
        avg_train_loss=0
        for batch_idx,(img,label) in enumerate(train_dl):
            #Zero_grad_Optimizer
            config["optim"].zero_grad()
            
            output=model(img)
            
            #Measure Loss
            loss=config["crit"](output,label)
            loss.backward()
            
            #Update Parameters
            config["optim"].step()
            avg_train_loss+=loss.item()
            if batch_idx % config["log_interval"]:
                config["train_log"].append(avg_train_loss/(batch_idx+1))
                print(loss.item())
        model.eval()
        valid_loss=0
        for batch_idx,(img_label) in enumerate(valid_dl):
            output=model(img)
            #Measure Loss
            loss=config["crit"](output,label)
            #Update Parameters
            valid_loss+=loss.item()
        config["valid_log"].append(valid_loss/(batch_idx+1))
    return config,model
            
    

In [101]:
config_,model=train(config_,train_dl,valid_dl,model)

Epochs : 1
0.48571857810020447
0.47112151980400085
0.4897589087486267
Epochs : 2
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 3
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 4
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 5
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 6
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 7
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 8
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 9
0.4857294261455536
0.47112151980400085
0.4897589087486267
Epochs : 10
0.4857294261455536
0.47112151980400085
0.4897589087486267
