# In [None]:  (Jupyter cell marker left over from the notebook export)
#!/usr/bin/env python3

from pathlib import Path
import torch
import torchvision
import datasets
import threading


# super(MultiHeadedAttention, self).__init__()


class MyModel(torch.nn.Module):
    """Wrap a module and print where each forward pass executed.

    All positional and keyword arguments are forwarded unchanged to
    ``inner_model``. After the call, one line is printed containing the native
    thread id, the current CUDA device index, and the output's device and
    shape. The inner model's output is returned untouched.
    """

    def __init__(self, inner_model):
        """Store ``inner_model`` on the wrapper.

        Args:
            inner_model: The callable module whose forward pass is traced.
        """
        super().__init__()
        self.inner_model = inner_model
    # end

    def forward(self, *args, **kwargs):
        """Run ``inner_model`` and print a one-line placement trace.

        Returns:
            Whatever ``inner_model(*args, **kwargs)`` returns, unchanged.
        """
        output = self.inner_model(*args, **kwargs)

        # The trace is diagnostics only and must never break the forward pass:
        # only query the CUDA device index when CUDA is actually usable ...
        current = torch.cuda.current_device() if torch.cuda.is_available() else None
        # ... and tolerate outputs that are not a single tensor (no
        # ``device`` / ``shape`` attribute) by printing ``None`` instead.
        print('[{}] run model, cuda: [{}], output: [{}], shape: {}'.format(
            threading.get_native_id(),
            current,
            getattr(output, 'device', None),
            getattr(output, 'shape', None),
        ))
        return output
    # end
# end


def load_data(num_gpus, root='./PetImages', batch_size=64, shuffle=True):
    """Build the training ``DataLoader`` over an image-folder dataset.

    Each image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized per channel before batching.

    Args:
        num_gpus: Number of GPUs in use; the loader gets ``4 * num_gpus``
            worker processes (so ``0`` GPUs means loading in the main process).
        root: Directory passed to ``torchvision.datasets.ImageFolder``.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``True`` because this loader feeds the training loop, and an
            unshuffled ``ImageFolder`` yields its samples grouped by class
            directory, i.e. batches that contain a single class.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        # Per-channel mean/std (the commonly used ImageNet statistics).
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = torchvision.datasets.ImageFolder(root=root, transform=preprocess)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4 * num_gpus,
    )


def run_training(num_gpus):
    """Train a ResNet-50 with ``DataParallel`` across ``num_gpus`` GPUs.

    The model is wrapped in ``MyModel`` (so every replica prints where it ran),
    moved to CUDA and replicated over devices ``0 .. num_gpus - 1``. It is then
    optimized with SGD and cross-entropy loss for ``num_epochs`` epochs over
    the loader returned by ``load_data``. Placement traces are printed every
    step and the loss every 10th step.

    Args:
        num_gpus: Number of CUDA devices to spread each batch over; also
            forwarded to ``load_data``.
    """
    model = torchvision.models.resnet50(pretrained=False)
    model = MyModel(model)
    model = model.cuda()
    model = torch.nn.parallel.DataParallel(model, device_ids=list(range(num_gpus)), dim=0)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    criterion = torch.nn.CrossEntropyLoss()
    criterion.cuda()
    model.train()

    num_epochs = 30
    dataloader = load_data(num_gpus)
    total_steps = len(dataloader)

    # Epochs are 1-based in the log output, so the range must be inclusive of
    # ``num_epochs``; ``range(1, num_epochs)`` would stop one epoch short of
    # the total the progress line reports.
    for epoch in range(1, num_epochs + 1):
        print(f'\nEpoch {epoch}\n')

        # TODO: resume from the previous epoch's checkpoint here
        # (``load_model(epoch - 1, model, optimizer)`` was stubbed out).

        for step, (images, labels) in enumerate(dataloader, 1):
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            # Argument order follows the labels in the format string:
            # "cuda" is the current device index, "output" is the device the
            # gathered outputs live on.
            print('[{}] cuda: [{}], output: {}, shape: {}, label: {}, shape: {}'.format(
                threading.get_native_id(),
                torch.cuda.current_device(),
                outputs.device,
                outputs.shape,
                labels.device,
                labels.shape,
            ))
            loss = criterion(outputs, labels)
            print('[{}] calculate loss, cuda: [{}]'.format(threading.get_native_id(), torch.cuda.current_device()))
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print(f'Epoch [{epoch} / {num_epochs}], Step [{step} / {total_steps}], Loss: {loss.item():.4f}')
            # end
        # TODO: save a per-epoch checkpoint here
        # (``save_model(epoch, model, optimizer)`` was stubbed out).

# Script entry point: count the visible CUDA devices, report the count, and
# train across all of them.
if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    print('num_gpus: ', num_gpus)
    # Disabled: forcing the 'spawn' start method for multiprocessing.
    # torch.multiprocessing.set_start_method('spawn')
    run_training(num_gpus)