In [None]:
import os

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from baseline import BaselineModel

In [None]:
# Point MLflow at the tracking server and select the experiment.
#
# `mlflow_uri` is not assigned anywhere in this notebook, so referencing it
# directly raises NameError on a fresh kernel. Resolve it explicitly instead:
# keep a value that was already defined/injected into the namespace, otherwise
# read the MLFLOW_TRACKING_URI environment variable.
mlflow_uri = globals().get("mlflow_uri") or os.environ.get("MLFLOW_TRACKING_URI")
if mlflow_uri:
    mlflow.set_tracking_uri(mlflow_uri)
# When no URI is configured, MLflow keeps its own default tracking location.
mlflow.set_experiment('cifar10-classification')

In [None]:
# Mini-batch size used by both the training and the validation loader.
batch_size = 4

# Preprocessing applied to every image: tensor conversion followed by a
# per-channel Normalize with 0.5 / 0.5 for each of the three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Keyword arguments shared by the two dataset splits and the two loaders.
_dataset_kwargs = dict(root='./data', download=True, transform=transform)
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Training split: reshuffled by the loader.
train_set = torchvision.datasets.CIFAR10(train=True, **_dataset_kwargs)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **_loader_kwargs)

# Held-out split (train=False): iterated in fixed order.
validation_set = torchvision.datasets.CIFAR10(train=False, **_dataset_kwargs)
validation_loader = torch.utils.data.DataLoader(validation_set, shuffle=False, **_loader_kwargs)

# Human-readable names for the ten class labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Instantiate the baseline network and move its parameters to `device`.
# The call is left as the last expression so the cell still displays its result.
baseline_model = BaselineModel()
baseline_model.to(device)

In [None]:
# Training hyper-parameters: number of passes over the data and Adam step size.
n_epochs, lr = 3, 0.001

# Adam over every parameter of the baseline model, trained against cross-entropy.
optimizer = torch.optim.Adam(params=baseline_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [None]:
# Record the run configuration in MLflow.
#
# "training_set" / "validation_set" are logged as the number of samples in each
# dataset. The previous values, len(loader), were the number of batches, which
# silently changes whenever batch_size changes; batch_size is logged alongside
# so the batch count can still be derived.
mlflow.log_params({
    'n_epochs': n_epochs,
    "learning_rate": lr,
    "batch_size": batch_size,
    "training_set": len(train_set),
    "validation_set": len(validation_set),
})

In [None]:
# Train for `n_epochs` epochs, evaluating on the validation loader after each
# one. Per-epoch training loss / accuracy are appended to `loss_p` /
# `accuracy_p`, and all four metrics are logged to MLflow.
loss_p = np.array([])
accuracy_p = np.array([])
for epoch in range(n_epochs):

    # ---- training ----
    baseline_model.train()
    total_image = 0
    correct_image = 0
    running_loss = 0.0
    for i, (image, label) in enumerate(train_loader):
        # The model was moved to `device`; the batch has to live there as well,
        # otherwise the forward pass fails whenever CUDA is available.
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()

        output = baseline_model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # argmax over the class dimension; avoids binding `__`, which IPython
        # uses for its output history.
        predicts = output.argmax(dim=1)
        total_image += label.size(0)
        correct_image += (predicts == label).sum().item()
        running_loss += loss.item()
        if i % 100 == 0:
            print('batch:{}/{}, accuracy:{}'.format(i, len(train_loader), correct_image/total_image*100), end='\r')

    train_loss = running_loss / len(train_loader)
    train_accuracy = correct_image / total_image * 100
    print('Epoch:{}, loss:{}, accuracy:{}'.format(epoch+1, train_loss, train_accuracy))
    loss_p = np.append(loss_p, train_loss)
    accuracy_p = np.append(accuracy_p, train_accuracy)
    # Pass `step` so each epoch becomes its own point instead of every value
    # being recorded at step 0. Steps are 1-based to match the printed epoch.
    mlflow.log_metric("train_loss", train_loss, step=epoch + 1)
    mlflow.log_metric("train_accuracy", train_accuracy, step=epoch + 1)

    # ---- validation ----
    baseline_model.eval()
    validation_loss = 0.0
    validation_total_image = 0
    validation_correct_image = 0
    # Gradients are not needed for evaluation; disable them once for the whole pass.
    with torch.no_grad():
        for validation_image, validation_label in validation_loader:
            validation_image = validation_image.to(device)
            validation_label = validation_label.to(device)

            # Evaluate on the validation batch itself (previously the last
            # training batch `image` was fed to the model here).
            output = baseline_model(validation_image)
            loss = criterion(output, validation_label)
            predict = output.argmax(dim=1)

            validation_total_image += validation_label.size(0)
            validation_correct_image += (predict == validation_label).sum().item()
            # .item() keeps the accumulator a plain float rather than a tensor,
            # so a Python number is what gets logged below.
            validation_loss += loss.item()

    mlflow.log_metric("validation_loss", validation_loss / len(validation_loader), step=epoch + 1)
    mlflow.log_metric("validation_accuracy", validation_correct_image / validation_total_image * 100, step=epoch + 1)

In [None]:
# Store the trained network as a run artifact and register it in the model
# registry; the artifact path and the registered name are the same string.
_model_name = "cifar10-classifier"
mlflow.pytorch.log_model(
    baseline_model,
    artifact_path=_model_name,
    registered_model_name=_model_name,
)

In [None]:
# Close the active MLflow run, stating the default terminal status explicitly.
mlflow.end_run(status="FINISHED")