In [1]:
!pip install torch
!pip install torchvision



In [2]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from mls import mls

In [3]:
# Load the preprocessed train and evaluation splits (train first, then eval).
# NOTE(review): the "validation" split is read from test_data.csv and is later
# used for LR scheduling and best-checkpoint selection — confirm this is
# intended, otherwise hold out a separate validation split to avoid leakage.
train_dataset, val_dataset = (
    mls.CustomDataSet(csv_path)
    for csv_path in (
        "./dataset/processed/train_data.csv",
        "./dataset/processed/test_data.csv",
    )
)

In [4]:
# Hyperparameters and run settings.
batch_size = 16
shuffle_dataset = True  # shuffle training batches each epoch
num_epochs = 150
# assumes each dataset item is (features, label) — TODO confirm in mls.CustomDataSet
in_features = len(train_dataset[0][0])
# NOTE(review): out_classes is not passed to mls.MVC — confirm the model's output size matches.
out_classes = 5
save_model = False
save_interval = 10  # when save_model is True, checkpoint every save_interval epochs
model_p = './checkpoints/'
seed = 0

# Seed torch (weight init, DataLoader shuffling) and numpy so runs are reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

In [5]:
# Training batches follow the shuffle_dataset setting instead of a hard-coded True.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_dataset)
# The training loop divides the validation totals by len(val_dataset), so sample order
# presumably does not affect the reported metrics; keep it fixed for reproducibility.
# batch_size=1 is kept as-is because that normalization may assume one sample per batch.
validation_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

In [6]:
# Model, loss, optimizer and learning-rate schedule.
# .double() makes the parameters float64 — presumably to match the dataset's dtype; confirm.
model = mls.MVC(in_features).double()
# The previous value (0.1) was never passed to the optimizer, so every run so far
# used Adam's default lr of 1e-3. Set it explicitly so this variable is the real knob.
learning_rate = 1e-3
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The training loop steps this with the validation loss; the LR is lowered when it plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

In [10]:
# Per-epoch history, written to log.json by the next cell.
train_loss = []
train_acc = []
val_loss = []
val_acc = []
lrs = []
best_loss = np.inf
for epoch in range(num_epochs):
    # Totals returned by mls are normalized to per-sample loss and percent accuracy.
    t_loss, t_acc = mls.training_epoch(model, optimizer, loss_f, train_loader)
    t_loss = t_loss / len(train_dataset)
    t_acc = 100 * (t_acc / len(train_dataset))

    v_loss, v_acc = mls.validation_epoch(model, loss_f, validation_loader)
    v_loss = v_loss / len(val_dataset)
    v_acc = 100 * (v_acc / len(val_dataset))

    train_loss.append(t_loss)
    val_loss.append(v_loss)
    train_acc.append(t_acc)
    val_acc.append(v_acc)  # was appending v_loss by mistake
    # Record the LR used for this epoch before the scheduler may lower it.
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step(v_loss)

    # Track the best validation loss every epoch so the flag passed to
    # save_model means "best seen so far" rather than always True.
    is_best = v_loss < best_loss
    if is_best:
        best_loss = v_loss

    if save_model and epoch % save_interval == 0:
        mls.save_model(model_p, epoch, model, optimizer, v_loss, is_best)

    print(f'e{epoch}: Train Loss: {t_loss:.04f}; Val Loss: {v_loss:.04f} acc: {v_acc:.04f}')
        

TypeError: unsupported format string passed to list.__format__

In [None]:
# Persist the per-epoch training history next to the checkpoints.
training_info = {
    'train_loss': train_loss,
    'train_acc': train_acc,
    'val_loss': val_loss,
    'val_acc': val_acc,
    'lrs': lrs,
}
# Serialize before opening the file so a serialization error doesn't leave a truncated log.
# default=float converts scalar-like values json can't handle natively (e.g. numpy
# float32 or 0-d tensors, if mls returns those — unverified).
json_dumps_str = json.dumps(training_info, indent=4, default=float)

# The checkpoint directory may not exist yet (e.g. when save_model is False).
os.makedirs(model_p, exist_ok=True)
with open(os.path.join(model_p, 'log.json'), 'w') as fout:
    print(json_dumps_str, file=fout)