In [None]:
import torch
import time as t 

In [None]:
def to_device(data, device):
    """Recursively transfer *data* to *device*.

    Lists and tuples are walked element by element and always come back as a
    ``list``. Any other object is moved with ``.to(device, non_blocking=True)``
    (e.g. a ``torch.Tensor``).
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(elem, device) for elem in data]

In [None]:
class DeviceDataLoader:
    """Iterable wrapper that transfers every batch of a data loader to a device.

    Attributes:
        dl: the wrapped iterable of batches (must support ``iter`` and ``len``).
        device: target handed to ``to_device`` for each batch.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of ``self.dl`` after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.dl)

In [None]:
def evaluate(model, val_loader):
    """Evaluate the model's performance on the validation set.

    Runs with autograd disabled: validation never calls ``backward()``, so
    recording a computation graph per batch would only waste memory and time.

    Args:
        model: object exposing ``validation_step(batch)`` and
            ``validation_epoch_end(outputs)``.
        val_loader: iterable of validation batches.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the list of
        per-batch outputs produced by ``model.validation_step``.
    """
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
        return model.validation_epoch_end(outputs)

def fit(epochs, lr, mo, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train the model using gradient descent, validating after every epoch.

    Args:
        epochs: number of passes over ``train_loader``.
        lr: learning rate, passed as the 2nd positional argument of ``opt_func``.
        mo: passed as the 3rd positional argument of ``opt_func``. This is
            ``momentum`` for ``torch.optim.SGD``. NOTE(review): other optimizers
            interpret this slot differently (e.g. ``betas`` for Adam).
        model: object exposing ``parameters()``, ``training_step(batch)``
            (returning a scalar loss tensor), ``epoch_end(epoch, result)`` and
            the validation hooks used by ``evaluate``.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        opt_func: optimizer factory, ``torch.optim.SGD`` by default.

    Returns:
        List with one result per epoch, as returned by ``evaluate``, each
        augmented with an ``'epoch_time'`` entry. The entry holds elapsed
        seconds for the training and validation phases combined.
    """
    history = []
    optimizer = opt_func(model.parameters(), lr, mo)
    for epoch in range(epochs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted, corrupting interval measurements.
        start = t.perf_counter()
        # Training phase
        for batch in train_loader:
            loss = model.training_step(batch)
            # retain_graph=True is kept as in the original; presumably some model
            # state is reused across backward passes. NOTE(review): if it is not
            # required, dropping it frees each batch's graph and saves memory.
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase
        result = evaluate(model, val_loader)
        result['epoch_time'] = t.perf_counter() - start
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of predictions that match ``labels``.

    Args:
        outputs: score tensor whose dim 1 indexes classes; the prediction for
            each position is the index of the maximum along dim 1.
        labels: tensor of target class indices, shaped like the predictions.

    Returns:
        0-dim float32 tensor in [0, 1], on the same device as ``outputs``.
        An empty batch yields ``nan``.
    """
    _, preds = torch.max(outputs, dim=1)
    # Stay on-device: the previous ``.item()`` round-trip forced a host sync on
    # every call, always produced a CPU tensor, and divided by the size of
    # dim 0 only (wrong for inputs with more than 2 dimensions).
    return (preds == labels).float().mean()