In [1]:
! pip install mpi4py

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mpi4py
  Downloading mpi4py-3.1.4.tar.gz (2.5 MB)
[K     |████████████████████████████████| 2.5 MB 23.4 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: mpi4py
  Building wheel for mpi4py (PEP 517) ... [?25l[?25hdone
  Created wheel for mpi4py: filename=mpi4py-3.1.4-cp38-cp38-linux_x86_64.whl size=4438394 sha256=bfb793f5704cfe9230bf571b29f5395b9f4e342af68476241df05e37a9b39e33
  Stored in directory: /root/.cache/pip/wheels/f3/35/48/0b9a7076995eea5ea64a7e4bc3f0f342f453080795276264e7
Successfully built mpi4py
Installing collected packages: mpi4py
Successfully installed mpi4py-3.1.4


In [6]:
%%writefile DataSet.py
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
train_data = datasets.MNIST(
      root = 'data',
      train = True,                         
      transform = ToTensor(), 
      download = True,            
  )

test_data = datasets.MNIST(
    root = 'data', 
    train = False, 
    transform = ToTensor()
)

Overwriting DataSet.py


In [7]:
%%writefile DataLoader.py
from torch.utils.data import DataLoader
import torch
from DataSet import train_data,test_data
def DataLoader(batchSize, numWorkers,shuffle = False,):
  loaders = {
  'train' : torch.utils.data.DataLoader(train_data, 
                                          batch_size=batchSize, 
                                          shuffle=shuffle, 
                                          num_workers=numWorkers),
  
  'validate'  : torch.utils.data.DataLoader(test_data, 
                                          batch_size=batchSize, 
                                          shuffle=True, 
                                          num_workers=numWorkers)
  }
  return loaders     


Overwriting DataLoader.py


In [4]:
%%writefile Model.py
import torch.nn as nn
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output

Writing Model.py


In [10]:
%%writefile CursMain.py
from mpi4py import MPI
import torch.optim as optim
import torch
from DataLoader import DataLoader
from torch.autograd import Variable
from Model import CNN
import torch.nn as nn
from tqdm import tqdm
def train_model(num_epochs,
                criterion,
                test_dataloader,
                rank,
                batch_size,
                optimizer,
                model,
                train_dataloader,
                val_dataloader):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model_best = 0
    for epoch in range(num_epochs):
      running_loss = 0
      accuracy = 0

      dataset_sizes_train = len(train_dataloader)
      model.train()
      if rank == 0:
        i = 0
        for images, labels in tqdm(train_dataloader):
          comm.send((images, labels),dest = i%(p-1)+1)
          i+=1
      if rank != 0:
        for i in range(len(train_dataloader)):    
          if i % (p - 1) + 1 == rank:
            (images,lables) = comm.recv(source=0)
            images = images.to(device)
            lables = lables.to(device)

            output = model(images)
            loss = criterion(output,lables)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
        running_loss = running_loss / dataset_sizes_train
        print("Epoch of train:", epoch + 1,"Loss: [", running_loss, "]", "rank: ", rank)
            
      MPI.Comm.Barrier(MPI.COMM_WORLD)
      
      accuracy = 0
      validate_loss = 0.0
      dataset_sizes_val = len(val_dataloader)
      if rank != 0:
        model.eval()
      if rank == 0:
        i = 0
        for images, labels in tqdm(val_dataloader):
          comm.send((images,labels), dest=i % (p - 1) + 1)
          i+=1
      if rank != 0:
        for i in range(len(val_dataloader)):    
          if i % (p - 1) + 1 == rank:
            (images,lables) = comm.recv(source=0)
            images = images.to(device)
            lables = lables.to(device)
            with torch.no_grad():
              output = model(images)
              loss = criterion(output,lables)
              validate_loss += loss.item() * images.size(0)
              pred_y = torch.max(output, 1)[1].data.squeeze()
        validate_loss = validate_loss / dataset_sizes_val
        print("Epoch of validation:", epoch + 1,"Loss: ",validate_loss, rank)
      MPI.Comm.Barrier(MPI.COMM_WORLD) 
      if rank != 0: 
        if epoch == 0:
          model_best = validate_loss
        if validate_loss <= model_best:
          model_best = validate_loss
          torch.save(model.state_dict(), f"./weights/model_{rank}.pth")
    return model

if __name__ == "__main__":
  comm = MPI.COMM_WORLD
  my_rank = comm.Get_rank()
  p = comm.Get_size()
  num_epochs = 3
  batch_size = 30
  num_workers = 4
  train_dataloader = DataLoader(batch_size,num_workers)['train']
  validate_dataloader = DataLoader(batch_size,num_workers)['validate']
  model = CNN()
  optimizer = optim.Adam(model.parameters(), lr=0.001)
  criterion = nn.CrossEntropyLoss()
  model = train_model(num_epochs,criterion,validate_dataloader,my_rank,
                           batch_size,optimizer,model,train_dataloader,validate_dataloader)
  MPI.Finalize                         


Overwriting CursMain.py


In [11]:
! mpirun --allow-run-as-root -np 4 python CursMain.py

 99%|█████████▉| 1982/2000 [00:20<00:00, 129.98it/s]Epoch of train: 1 Loss: [ 3.174856982465135 ] rank:  3
100%|█████████▉| 1999/2000 [00:20<00:00, 141.31it/s]Epoch of train: 1 Loss: [ 2.8188923583982977 ] rank:  1
Epoch of train: 1 Loss: [ 2.93712607084075 ] rank:  2
100%|██████████| 2000/2000 [00:20<00:00, 97.10it/s] 
 98%|█████████▊| 326/334 [00:02<00:00, 147.20it/s]Epoch of validation: 1 Loss:  0.9887085081507524 2
Epoch of validation: 1 Loss:  1.1480538361541324 3
Epoch of validation: 1 Loss:  0.9974066961227643 1
100%|██████████| 334/334 [00:02<00:00, 130.95it/s]
 99%|█████████▉| 1983/2000 [00:15<00:00, 126.77it/s]Epoch of train: 2 Loss: [ 0.9541301033759373 ] rank:  3
100%|█████████▉| 1998/2000 [00:15<00:00, 132.07it/s]Epoch of train: 2 Loss: [ 0.8603720277588581 ] rank:  1
Epoch of train: 2 Loss: [ 0.8783133789914427 ] rank:  2
100%|██████████| 2000/2000 [00:15<00:00, 128.25it/s]
 93%|█████████▎| 312/334 [00:02<00:00, 158.94it/s]Epoch of validation: 2 Loss:  0.7483845197562196 

In [83]:
%%writefile Test.py
from mpi4py import MPI
import torch.optim as optim
import torch
from DataLoader import DataLoader
from torch.autograd import Variable
from Model import CNN
import torch.nn as nn
from tqdm import tqdm
def test(model, criterion, dataloader_test, dataset_sizes_test):
    score = 0
    runing_loss = 0.0
    model.eval()
    

    with torch.no_grad():
        if rank != 0:
            print("Start proccess number ",rank)
            for image, label in tqdm(dataloader_test): 
                output = model(image)
                comm.send(output, dest=0, tag=0) 
                if rank == 1:
                    comm.send(label, dest=0, tag=1)
                _, preds = torch.max(output, 1)
                loss = criterion(output, label)
                runing_loss += loss.item() * image.size(0)
                score += torch.sum(preds == label.data)
            epoch_acc = score.double() / dataset_sizes_test
            runing_loss = runing_loss / dataset_sizes_test
            print("Test process ", rank, ": score: [", epoch_acc.item(), "], loss: [", runing_loss, "]")
        
        MPI.Comm.Barrier(MPI.COMM_WORLD)
        result = 0
        if rank == 0:
            print("Start proccess number ",rank)
            for i in tqdm(range(len(dataloader_test))):
                label = comm.recv(source=1, tag=1) 
                for procid in range(1, p):
                    output = comm.recv(source=procid, tag=0)
                    if procid == 1:
                        result_all_models = output
                    else:
                        result_all_models += output
                _, preds = torch.max(result_all_models, 1)
                result += torch.sum(preds == label.data)
            result = result.double() / dataset_sizes_test
            print("Test process result", rank, result.item())

if __name__ == "__main__":

  comm = MPI.COMM_WORLD
  rank = comm.Get_rank()
  p = comm.Get_size()
  num_epochs = 3
  batch_size = 30
  num_workers = 4
  train_dataloader = DataLoader(batch_size,num_workers)['train']
  validate_dataloader = DataLoader(batch_size,num_workers)['validate']
  model = CNN()
  if(rank != 0):
     model.load_state_dict(torch.load(f'/content/weights/model_{rank}.pth'))
  optimizer = optim.Adam(model.parameters(), lr=0.001)
  criterion = nn.CrossEntropyLoss()
  model = test(model,criterion,validate_dataloader,len(validate_dataloader) * batch_size)
  MPI.Finalize    

Overwriting Test.py


In [84]:
! mpirun --allow-run-as-root -np 4 python Test.py

Start proccess number  1
  4%|▍         | 14/334 [00:00<00:16, 19.44it/s]Start proccess number  2
  4%|▍         | 14/334 [00:00<00:14, 21.79it/s]Start proccess number  3
100%|██████████| 334/334 [00:14<00:00, 23.55it/s]Test process  2 : score: [ 0.9755489021956087 ], loss: [ 0.06615605563419813 ]
100%|██████████| 334/334 [00:16<00:00, 19.74it/s]Test process  1 : score: [ 0.981936127744511 ], loss: [ 0.04612626874586567 ]
100%|██████████| 334/334 [00:15<00:00, 20.93it/s]Test process  3 : score: [ 0.9760479041916168 ], loss: [ 0.06548089371242899 ]


Start proccess number  0

100%|██████████| 334/334 [00:00<00:00, 466.45it/s]Test process result 0 0.41097804391217563

