In [None]:
import torch as t, torch.nn as nn, torch.nn.functional as F, torch.distributions as tdist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import random_split
import torchvision as tv, torchvision.transforms as tr
import os
import sys
import numpy as np
#import wideresnet # from The Google Research Authors
import json
from torchvision import datasets
from pathlib import Path
#from wrn import WRN
from zntrack import ZnTrackProject, Node, config, dvc, zn
from zntrack.utils.decorators import check_signature

# config.nb_name must be the notebook's file name so that ZnTrack can generate the associated src/*.py scripts.
config.nb_name = "TrainNN.ipynb"
# Project handle; project.repro() is called at the end of the notebook to run the pipeline.
project = ZnTrackProject()

In [None]:
class WRN(nn.Module):
    """Sequential convolutional classifier with 10 output logits.

    Despite the name, the layers below form a plain feed-forward CNN with no
    residual connections: three convolutional stages followed by a
    three-layer fully connected head.  Every stage halves the spatial size
    (2x2 max-pool), and the head's first ``Linear`` expects ``256*4*4``
    features, so the network is sized for 3-channel 32x32 inputs.

    All layers live in the flat ``self.network`` container, so parameter names
    are ``network.<index>.<weight|bias|...>``.
    """

    @staticmethod
    def _conv_stage(c_in, c_mid, c_out, activation):
        """Return the layers of one stage: conv-act-conv-act-pool-batchnorm."""
        return [
            nn.Conv2d(c_in, c_mid, kernel_size=3, padding=1),
            activation(),
            nn.Conv2d(c_mid, c_out, kernel_size=3, stride=1, padding=1),
            activation(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(c_out),
        ]

    def __init__(self):
        super().__init__()
        layers = []
        # Stage 1 uses Softplus activations: [N, 3, 32, 32] -> [N, 64, 16, 16]
        layers += self._conv_stage(3, 32, 64, nn.Softplus)
        # Stage 2: [N, 64, 16, 16] -> [N, 128, 8, 8]
        layers += self._conv_stage(64, 128, 128, nn.ReLU)
        # Stage 3: [N, 128, 8, 8] -> [N, 256, 4, 4]
        layers += self._conv_stage(128, 256, 256, nn.ReLU)
        # Classifier head: flatten to 256*4*4 features, then 1024 -> 512 -> 10.
        layers += [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, xb):
        """Run the batch ``xb`` through the whole stack and return the logits."""
        logits = self.network(xb)
        return logits

In [None]:
class Base:
    """Minimal interface for computation objects: subclasses provide ``compute``."""

    def compute(self, inp):
        """Perform the computation for ``inp``.

        Raises:
            NotImplementedError: always; this base version must be overridden.
        """
        raise NotImplementedError()

In [None]:
# How to define ML training workflow through ZnTrack
# 1. create computation class extending Base, implement compute function

class Trainer(Base):
    """Train and validate a ``WRN`` model on the CIFAR-10 training set.

    The 50k-style training set is split once per ``Trainer`` into a training
    and a validation subset (the last ``VAL_SIZE`` samples of a seeded random
    permutation), so validation data is never trained on and the split stays
    the same across epochs.  The model, loss and optimizer are created once in
    ``__init__`` and reused for every epoch.

    ``compute(inp)`` is the entry point used through the ``Base`` interface;
    ``inp`` is the number of epochs.
    """

    # Data / loader settings shared by train() and test().
    DATA_ROOT = 'root'   # directory handed to torchvision's CIFAR10 (download=False)
    BATCH_SIZE = 64
    VAL_SIZE = 5000      # samples held out for validation
    NUM_WORKERS = 4
    SPLIT_SEED = 0       # fixes the train/validation split

    @check_signature
    def __init__(self):
        self.device = t.device('cuda' if t.cuda.is_available() else 'cpu')
        # The model must live on the same device as the batches fed to it;
        # it is moved before the optimizer is built from its parameters.
        self.model = WRN().to(self.device)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = self._make_optimizer(self.model)
        # Remember which model the optimizer belongs to (see train()).
        self._optimizer_model = self.model
        # (train_loader, val_loader), built lazily on first use.
        self._loaders = None

    @staticmethod
    def _make_optimizer(model):
        """Build the Adam optimizer used for ``model``."""
        return t.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)

    def _get_loaders(self):
        """Return ``(train_loader, val_loader)``, creating and caching them once."""
        loaders = getattr(self, "_loaders", None)
        if loaders is None:
            # Per-channel mean / std passed to Normalize.
            transform_normal = tr.Compose(
                [tr.ToTensor(), tr.Normalize((.49, .48, .44), (.24, .24, .26))]
            )
            normal_train = tv.datasets.CIFAR10(
                root=self.DATA_ROOT, transform=transform_normal, download=False, train=True
            )
            n_val = self.VAL_SIZE
            n_train = len(normal_train) - n_val
            # Seeded generator -> identical, disjoint subsets on every run.
            split_rng = t.Generator().manual_seed(self.SPLIT_SEED)
            train_ds, val_ds = random_split(normal_train, [n_train, n_val], generator=split_rng)
            # Pinned memory only helps host-to-GPU copies.
            pin = self.device.type == 'cuda'
            train_dl = DataLoader(train_ds, self.BATCH_SIZE, shuffle=True,
                                  num_workers=self.NUM_WORKERS, pin_memory=pin)
            val_dl = DataLoader(val_ds, self.BATCH_SIZE,
                                num_workers=self.NUM_WORKERS, pin_memory=pin)
            loaders = self._loaders = (train_dl, val_dl)
        return loaders

    def train(self, model):
        """Run one training epoch of ``model`` over the training subset.

        Prints the current batch loss every 100 batches.  The optimizer is
        reused between calls for the same model so its running statistics
        survive across epochs; a new one is only built when a different model
        object is passed in.
        """
        dataloader, _ = self._get_loaders()
        model.to(self.device)
        if getattr(self, "_optimizer_model", None) is not model:
            self.optimizer = self._make_optimizer(model)
            self._optimizer_model = model
        optimizer, loss_fn = self.optimizer, self.loss_fn

        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            # X is the batch of images, y is the vector of numeric labels for them
            X, y = X.to(self.device), y.to(self.device)
            loss = loss_fn(model(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch % 100 == 0:
                current = batch * len(X)
                print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")

    def test(self, model):
        """Evaluate ``model`` on the held-out validation subset.

        Returns:
            tuple: ``(accuracy_percent, mean_batch_loss)``.
        """
        _, dataloader = self._get_loaders()
        model.to(self.device)
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0.0, 0.0
        with t.no_grad():
            for X, y in dataloader:
                X, y = X.to(self.device), y.to(self.device)
                pred = model(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(t.float).sum().item()
        test_loss /= num_batches
        correct /= size
        acc = 100 * correct
        print(f"Test Error: \n Accuracy: {acc:>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return acc, test_loss

    def dotraining(self, epochs, resultsfile=None):
        """Train for ``epochs`` epochs, validating after each one.

        Args:
            epochs: number of train/validate rounds.
            resultsfile: optional path of a JSON file to write the scores to.
                When omitted, ``self.resultsfile`` is used if a caller has set
                it; if neither is available nothing is written.

        Returns:
            dict: epoch index -> ``{"acc": float, "loss": float}``.
        """
        test_scores = {}
        for step in range(epochs):
            print(f"Epoch {step+1}\n-------------------------------")
            self.train(self.model)
            test_acc, test_loss = self.test(self.model)
            test_scores[step] = {"acc": float(test_acc), "loss": float(test_loss)}

        if resultsfile is None:
            resultsfile = getattr(self, "resultsfile", None)
        if resultsfile is not None:
            resultsfile = Path(resultsfile)
            # open(..., 'w') does not create missing directories.
            resultsfile.parent.mkdir(parents=True, exist_ok=True)
            with open(resultsfile, 'w') as outfile:
                json.dump(test_scores, outfile)
        return test_scores

    def compute(self, inp):
        """``Base`` entry point: train for ``inp`` epochs and return the scores."""
        return self.dotraining(inp)

In [None]:
# 2.  Create the Node class: implement __call__ and run, and optionally __init__.
# Use the zn.Method() type for the attribute that holds the Base class extension.
# This builds the .py script, so re-run this cell after any change at all.

@Node()
class Train:
    """ZnTrack node that runs a ``Base`` computation object as a pipeline stage.

    ``__call__`` stores the stage configuration; ``run`` executes the
    trainer's ``compute`` with the configured number of epochs and keeps its
    return value in ``result``.
    """
    epochs = dvc.params()           # number of epochs forwarded to trainer.compute
    trainer: Base = zn.Method()     # computation object whose compute() is called in run()
    resultsfile: Path = dvc.outs()  # path of the JSON file the scores are written to
    result = zn.outs()              # return value of trainer.compute

    def __init__(self):
        # NOTE(review): this model is not used by any method of this class in
        # the visible notebook; kept so the attribute stays available.
        self.model = WRN()

    def __call__(self, epochs, resultsfile: Path, trainer):
        """Configure the stage.

        Args:
            epochs: number of training epochs.
            resultsfile: path of the JSON score file (a file, not a directory).
            trainer: object implementing ``Base.compute``.
        """
        self.trainer = trainer
        self.epochs = epochs
        self.resultsfile = Path(resultsfile)
        # Create the directory that will contain the file; calling mkdir on
        # the file path itself would create a directory with the file's name
        # and make the later open(..., 'w') fail.
        self.resultsfile.parent.mkdir(exist_ok=True, parents=True)

    def run(self):
        """Execute the trainer and store what it returns in ``result``."""
        resultsfile = Path(self.resultsfile)
        resultsfile.parent.mkdir(exist_ok=True, parents=True)
        trainer = self.trainer
        # The trainer writes the scores itself, so it needs to know the target path.
        trainer.resultsfile = resultsfile
        self.result = trainer.compute(inp=self.epochs)


In [None]:
# 3.  Instantiate the classes and pass the Base compute object into the Node.
# Calling the Node creates and writes the dvc.yaml stages.

trainer = Trainer()

train_stage = Train()
# Configure the stage: 3 epochs, scores go to outs/test_scores.json.
train_stage(epochs=3, resultsfile=Path("outs", "test_scores.json"), trainer=trainer)

In [None]:
# Run the pipeline. This is equivalent to the shell command:
#   !dvc repro
project.repro()