In [1]:
import numpy as np 
import matplotlib.pyplot as plt
from tqdm import tqdm as notebook_tqdm

import torch
import torchvision
from torch import nn, optim
from torch.nn import init
from torchvision import datasets, transforms
from accelerate import Accelerator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7fcb3a82b3b0>

In [3]:
transformer = transforms.Compose([torchvision.transforms.ToTensor()])
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./files/', train=True, download=True, transform=transformer), batch_size=250, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./files/', train=False, download=True, transform=transformer), batch_size=250, shuffle=True)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5, stride=2),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Dropout2d(),
            nn.Flatten()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout()
        )
        self.fc2 = nn.Linear(50, 10)
        
        self.apply(self._init_weights)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
    
    def _init_weights(self, m):
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

In [5]:
def train(flag='DDP', n_epochs=5):
    device = torch.device('cuda:0')

    model = Net()
    
    if flag == 'DDP':
        model = torch.nn.DataParallel(model).to(device)
    elif flag == 'HFA':
        accelerator = Accelerator()
        model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = torch.nn.CrossEntropyLoss()


    train_losses = []
    test_losses = []
    test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

    highest_acc = 0
    for epoch in range(1, n_epochs + 1):
        train_bar = notebook_tqdm(train_loader)
        model.train()
        for batch_idx, (data, target) in enumerate(train_bar):
            optimizer.zero_grad()
            output = model(data.to(device))
            loss = loss_fn(output, target.to(device))
            
            if flag == 'DDP':
                loss.backward()
            elif flag == 'HFA':
                accelerator.backward(loss)
            
            optimizer.step()

            train_losses.append(loss.item())
            train_bar.set_description(f'Train Epoch: {epoch} Loss: {np.mean(train_losses):.6f}')
        

        model.eval()
        acc = []
        test_bar = notebook_tqdm(test_loader)
        for data, target in test_bar:
            with torch.no_grad():
                output = model(data.to(device))
            acc.extend((torch.argmax(output.softmax(1), dim=1) == target.to(device)).data.cpu().numpy())
            test_bar.set_description(f'Test set: Accuracy: {100. * np.mean(acc):.0f}%')
        
        
        if np.mean(acc) > highest_acc:
            highest_acc = np.mean(acc)

            if flag == 'DDP':
                mstate = model.module.state_dict()
            elif flag == 'HFA':
                mstate = model.state_dict()

            torch.save({'model_state_dict': mstate,
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch},
                        f'model_{flag}.pt')


In [6]:
train(flag='DDP')

Train Epoch: 1 Loss: 0.732992: 100%|██████████| 240/240 [00:13<00:00, 17.35it/s]
Test set: Accuracy: 96%: 100%|██████████████████| 40/40 [00:01<00:00, 30.79it/s]
Train Epoch: 2 Loss: 0.553780: 100%|██████████| 240/240 [00:11<00:00, 21.72it/s]
Test set: Accuracy: 96%: 100%|██████████████████| 40/40 [00:01<00:00, 30.07it/s]
Train Epoch: 3 Loss: 0.474658: 100%|██████████| 240/240 [00:11<00:00, 21.31it/s]
Test set: Accuracy: 97%: 100%|██████████████████| 40/40 [00:01<00:00, 31.89it/s]
Train Epoch: 4 Loss: 0.427375: 100%|██████████| 240/240 [00:10<00:00, 22.17it/s]
Test set: Accuracy: 97%: 100%|██████████████████| 40/40 [00:01<00:00, 31.05it/s]
Train Epoch: 5 Loss: 0.393820: 100%|██████████| 240/240 [00:11<00:00, 21.18it/s]
Test set: Accuracy: 97%: 100%|██████████████████| 40/40 [00:01<00:00, 35.00it/s]


In [7]:
train(flag='HFA')

Train Epoch: 1 Loss: 0.611305: 100%|██████████| 240/240 [00:11<00:00, 21.56it/s]
Test set: Accuracy: 96%: 100%|██████████████████| 40/40 [00:01<00:00, 37.20it/s]
Train Epoch: 2 Loss: 0.458234: 100%|██████████| 240/240 [00:11<00:00, 21.35it/s]
Test set: Accuracy: 97%: 100%|██████████████████| 40/40 [00:01<00:00, 36.81it/s]
Train Epoch: 3 Loss: 0.390392: 100%|██████████| 240/240 [00:11<00:00, 21.60it/s]
Test set: Accuracy: 97%: 100%|██████████████████| 40/40 [00:01<00:00, 32.60it/s]
Train Epoch: 4 Loss: 0.352052: 100%|██████████| 240/240 [00:11<00:00, 21.70it/s]
Test set: Accuracy: 98%: 100%|██████████████████| 40/40 [00:01<00:00, 34.53it/s]
Train Epoch: 5 Loss: 0.326551: 100%|██████████| 240/240 [00:10<00:00, 21.91it/s]
Test set: Accuracy: 98%: 100%|██████████████████| 40/40 [00:01<00:00, 36.18it/s]
