In [8]:
import torch
import numpy as np
# import matplotlib.pyplot as plt

# Basic pytorch
import torch.nn as nn
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# DDP
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

# from torchinfo import summary
import gc
import os

In [6]:
class CNN(nn.Module):
    """VGG-16-style convolutional classifier that returns raw class logits.

    The network is a stack of 3x3 'same'-padded convolutions (each followed by
    ReLU) interleaved with 2x2 max-pooling, an adaptive average pool to 7x7,
    and a three-layer fully connected head with dropout.

    The forward pass returns *unnormalised logits*. It deliberately does not
    end in a softmax: ``nn.CrossEntropyLoss`` applies log-softmax internally,
    so feeding it probabilities would apply softmax twice and flatten the
    gradients. Call ``torch.softmax(logits, dim=1)`` when probabilities are
    needed; ``argmax`` over the logits gives the same prediction either way.

    Args:
        num_classes: Size of the output layer.
        dropout: Drop probability used by every dropout layer in the head.
    """

    # Output channels of each 3x3 convolution, in order; 'M' marks a 2x2
    # max-pool with stride 2. Layer order (and therefore state_dict indices)
    # matches the explicit layer list this table replaces.
    _LAYOUT = (
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    )

    def __init__(self, num_classes=525, dropout=0.5):
        super().__init__()

        layers = []
        in_channels = 3  # the first convolution consumes 3-channel input
        for spec in self._LAYOUT:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2))
            else:
                layers.append(
                    nn.Conv2d(in_channels=in_channels, out_channels=spec,
                              kernel_size=(3, 3), padding='same')
                )
                layers.append(nn.ReLU())
                in_channels = spec
        self.features = nn.Sequential(*layers)

        # Fixes the spatial size fed to the head at 7x7 regardless of the
        # input resolution.
        self.adaptive = nn.AdaptiveAvgPool2d((7, 7))

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=dropout),
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            # Final layer emits logits; no Softmax here (see class docstring).
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.features(x)
        x = self.adaptive(x)
        return self.fc(x)

In [11]:
class DDPTrainer:
    """Small training/validation driver around ``DistributedDataParallel``.

    The trainer reads ``LOCAL_RANK`` and ``RANK`` from the environment (as
    provided by a distributed launcher), moves the model to the local device,
    restores a checkpoint from ``save_path`` if one exists, and then wraps the
    model in DDP.

    Checkpoints are dicts with two keys: ``'MODEL_STATE'`` (the unwrapped
    model's ``state_dict``) and ``'EPOCHS_RUN'`` (the epoch index at which the
    checkpoint was written).
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        loss_fn,
        save_freq: int,
        save_path: str):
        """Set up the trainer.

        Args:
            model: Network to train; moved to the local device and DDP-wrapped.
            train_loader: Loader iterated once per epoch by ``train_iter``.
            val_loader: Loader iterated by ``validate``.
            optimizer: Optimizer already bound to ``model``'s parameters.
            loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
            save_freq: A checkpoint is written every ``save_freq`` epochs.
            save_path: Checkpoint file; loaded on construction if it exists.

        Raises:
            KeyError: If ``LOCAL_RANK`` or ``RANK`` is missing from the environment.
        """
        self.local_rank = int(os.environ['LOCAL_RANK'])
        self.global_rank = int(os.environ['RANK'])
        self.model = model.to(self.local_rank)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.loss_fn = loss_fn

        self.save_freq = save_freq
        self.epochs_run = 0

        # Restore before wrapping so the checkpoint keys (saved from the
        # unwrapped module) match the model's own keys.
        if os.path.exists(save_path):
            print('Loading Saved State')
            self.load_saved_state(save_path)

        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _unwrapped_model(self):
        """Return the underlying module whether or not it is DDP-wrapped."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def load_saved_state(self, save_path):
        """Load model weights and the epoch counter from ``save_path``."""
        # Deserialize onto the CPU: without map_location every rank would
        # restore the tensors onto the device they were saved from (rank 0's
        # GPU). load_state_dict then copies values into the already-placed
        # parameters.
        old_state = torch.load(save_path, map_location='cpu')
        self._unwrapped_model().load_state_dict(old_state['MODEL_STATE'])
        self.epochs_run = old_state['EPOCHS_RUN']

    def train_iter(self):
        """Run one pass over ``train_loader``, stepping the optimizer per batch."""
        # validate() switches to eval mode; make sure dropout is active again.
        self.model.train()
        for X, y in self.train_loader:
            X, y = X.to(self.local_rank), y.to(self.local_rank)
            self.optimizer.zero_grad()
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            loss.backward()
            self.optimizer.step()

            # NOTE(review): collecting garbage and emptying the CUDA cache on
            # every batch is kept from the original; it is slow and is
            # normally unnecessary -- consider removing after profiling.
            del X, y, pred, loss
            gc.collect()
            torch.cuda.empty_cache()

    def validate(self):
        """Evaluate on ``val_loader`` and print accuracy and mean batch loss.

        Accuracy is the fraction of correctly classified *samples* seen by
        this process; loss is the mean of the per-batch losses. Both are
        reported as ``nan`` when the loader yields nothing. The model is put
        in eval mode for the duration of the call and its previous mode is
        restored afterwards. Metrics are local to this rank (no cross-rank
        reduction is performed).
        """
        was_training = self.model.training
        self.model.eval()  # disable dropout while measuring

        num_batches = 0
        num_samples = 0
        num_correct = 0
        total_loss = 0.0
        try:
            with torch.no_grad():
                for X, y in self.val_loader:
                    X, y = X.to(self.local_rank), y.to(self.local_rank)
                    pred = self.model(X)
                    # .item() keeps the running totals as Python numbers
                    # instead of accumulating device tensors.
                    total_loss += self.loss_fn(pred, y).item()
                    num_correct += (pred.argmax(1) == y).sum().item()
                    num_samples += y.size(0)
                    num_batches += 1

                    del X, y, pred
                    gc.collect()
                    torch.cuda.empty_cache()
        finally:
            self.model.train(was_training)

        acc = num_correct / num_samples if num_samples else float('nan')
        loss = total_loss / num_batches if num_batches else float('nan')
        print('Validation Accuracy:', str(acc))
        print('Validation Loss:', str(loss))

    def save_state(self, epoch: int, save_path: str):
        """Write the unwrapped model weights and ``epoch`` to ``save_path``."""
        state = {}
        state['MODEL_STATE'] = self._unwrapped_model().state_dict()
        state['EPOCHS_RUN'] = epoch
        torch.save(state, save_path)
        print(f'Epoch {epoch}, saving model at {save_path}')

    def train_loop(self, epochs: int, save_path: str):
        """Train from ``self.epochs_run`` up to (not including) ``epochs``.

        Local rank 0 checkpoints every ``save_freq`` epochs and once more after
        the last epoch if that epoch was not already saved. Nothing is saved
        when there are no epochs left to run.
        """
        last_epoch = None
        last_saved = None
        sampler = getattr(self.train_loader, 'sampler', None)

        for e in range(self.epochs_run, epochs):
            print(f'EPOCH {e} from GPU {self.global_rank}')
            # Distributed samplers seed their shuffle from the epoch number;
            # without this every epoch would see the same ordering.
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(e)
            self.train_iter()
            last_epoch = e
            if self.local_rank == 0 and e % self.save_freq == 0:
                self.save_state(e, save_path)
                last_saved = e

        # Final checkpoint: only one writer per node, only if an epoch
        # actually ran, and not a duplicate of the periodic save.
        if self.local_rank == 0 and last_epoch is not None and last_saved != last_epoch:
            self.save_state(last_epoch, save_path)
    

In [8]:
# model = CNN().to(device)
# summary(model, input_size=(1,3,224,224))

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [1, 525]                  --
├─Sequential: 1-1                        [1, 512, 7, 7]            --
│    └─Conv2d: 2-1                       [1, 64, 224, 224]         1,792
│    └─ReLU: 2-2                         [1, 64, 224, 224]         --
│    └─Conv2d: 2-3                       [1, 64, 224, 224]         36,928
│    └─ReLU: 2-4                         [1, 64, 224, 224]         --
│    └─MaxPool2d: 2-5                    [1, 64, 112, 112]         --
│    └─Conv2d: 2-6                       [1, 128, 112, 112]        73,856
│    └─ReLU: 2-7                         [1, 128, 112, 112]        --
│    └─Conv2d: 2-8                       [1, 128, 112, 112]        147,584
│    └─ReLU: 2-9                         [1, 128, 112, 112]        --
│    └─MaxPool2d: 2-10                   [1, 128, 56, 56]          --
│    └─Conv2d: 2-11                      [1, 256, 56, 56]          295,168

In [None]:
# Progress marker emitted before the process group is initialised.
print("Made it here")
# Initialise the default distributed process group with the NCCL backend.
init_process_group(backend='nccl')
# Make the GPU whose index is given by the LOCAL_RANK environment variable
# the current CUDA device for this process.
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

In [None]:
# Locations of the three dataset splits and the per-process batch size.
train_path = './bird_dataset/train/'
val_path = './bird_dataset/valid/'
test_path = './bird_dataset/test/'
batch_size = 64

# Shared preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm they suit this dataset.
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([_resize, _to_tensor, _normalize])


def _distributed_loader(dataset):
    """Wrap ``dataset`` in a DataLoader driven by its own DistributedSampler.

    ``batch_size`` is read from the module-level setting above; ``shuffle`` is
    left off on the loader because ordering is delegated to the sampler.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset),
    )


# Build every dataset first, then every loader, in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in (train_path, val_path, test_path)
)

train_loader, val_loader, test_loader = (
    _distributed_loader(split)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# --- Run configuration ---------------------------------------------------
epochs = 100
learning_rate = 1e-2

# Checkpoint file and how often (in epochs) it is written.
save_path = 'checkpoint.pt'
save_freq = 5

# --- Model, objective, optimiser -----------------------------------------
model = CNN()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

# --- Trainer ---------------------------------------------------------------
# Constructing the trainer loads `save_path` if it already exists.
DDP_model = DDPTrainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    save_freq=save_freq,
    save_path=save_path,
)

# Only validation is run here; the training call is left disabled.
DDP_model.validate()
# DDP_model.train_loop(epochs, save_path)


In [None]:
# Tear down the distributed process group once all distributed work is done.
destroy_process_group()