In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Subset, DataLoader, TensorDataset

import torchvision
import torchvision.transforms as transforms

import time
import copy
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
from accelerate import Accelerator

In [14]:
device_count = torch.cuda.device_count()
print(f"Device count: {device_count}")

for i in range(device_count):
    print(torch.cuda.get_device_properties(i))

Device count: 2
_CudaDeviceProperties(name='Z100SM', major=7, minor=5, total_memory=16368MB, multi_processor_count=64)
_CudaDeviceProperties(name='Z100SM', major=7, minor=5, total_memory=16368MB, multi_processor_count=64)


In [15]:
computing_device = "cuda"
if not torch.cuda.is_available():
    computing_device = "cpu"
print("Computing Device: ", computing_device)

Computing Device:  cuda


In [6]:
accelerator = Accelerator()
# print(accelerator)

In [16]:
batch_size = 256

# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 先四周填充0，在吧图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  # 图像一半的概率翻转，一半的概率不翻转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # R,G,B每层的归一化用到的均值和方差
])

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
validate_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_val)

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
valloader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

print(len(trainloader.dataset))
print(len(valloader.dataset))

50000
10000


In [17]:
# print(torchvision.models())
resnet = torchvision.models.resnet101(num_classes=10).to(computing_device)


# resnet = torchvision.models.resnet101(num_classes=10)


if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(resnet)

lr = 0.001
lossfn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)


# resnet, optimizer, trainloader, valloader = accelerator.prepare(model, optimizer, trainloader, valloader)


Using 2 GPUs!


In [18]:
print(sum(p.numel() for p in resnet.parameters()))   # 查看一下模型参数量

42520650


In [19]:
# print(model)
# for X, y in valloader:
#     print(X.size(), y.size())
#     break

# 测试一下模型的输入输出维度
tmpdata = torch.randn([128, 3, 32, 32]).to(computing_device)
tmpout = resnet(tmpdata)

print(tmpout.size())
print(tmpout)


torch.Size([128, 10])
tensor([[ 0.1384,  0.8438, -0.5476,  ...,  0.3943, -0.7210, -0.5561],
        [-0.6230,  0.6988, -1.2198,  ...,  1.6391, -1.3147, -0.5729],
        [-0.6027, -0.4884, -1.2437,  ...,  0.8441,  1.6953,  0.0948],
        ...,
        [-0.5671,  0.5456, -0.8684,  ...,  0.7286, -1.3601, -0.3616],
        [-0.2223,  0.5276, -2.0907,  ...,  0.9317, -1.5681,  0.1793],
        [ 0.3822,  0.0922, -1.8689,  ..., -0.0190,  0.8040, -0.7867]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


In [None]:
for k in resnet.state_dict():
    if k.find("bias") >= 1:
        continue
    print(k)

In [20]:
# 训练
def train_model(epoch, model, loss_fn, optimizer, trainloader):
    # training
    num_batches = len(trainloader)
    model.train()
    train_loss = 0
    for batch, (X, y) in enumerate(trainloader):
        X, y = X.to(computing_device), y.to(computing_device)
        optimizer.zero_grad()
        
        predict = model(X)

        loss = loss_fn(predict, y)
        train_loss += loss.item()

        loss.backward()
        # accelerator.backward(loss)
        optimizer.step()
        
    train_loss /= num_batches

    return train_loss

# 验证
def val_model(epoch, model, loss_fn, valloader):
    size = len(valloader.dataset)
    num_batches = len(valloader)
    
    model.eval()
    test_loss, val_correct = 0, 0
    with torch.no_grad():        
        for batch, (X, y) in enumerate(valloader):
            # X, y = X.to(computing_device), y.to(computing_device)

            predict = model(X)
            loss = loss_fn(predict, y)
            test_loss += loss.item()
            val_correct += (predict.argmax(1) == y).type(torch.float).sum().item() 
            
    test_loss /= num_batches
    val_correct /= size

    return test_loss, val_correct

# 测试
def test_model(model, loss_fn, testloader):
    pass


In [21]:
start_epoch = 0 # 从哪一个epoch开始
num_epochs = 5 # 要训练多少个epoch

# time_all = 0    # 消耗的总时长，单位s

for epoch in range(start_epoch, start_epoch + num_epochs):    
    ts = time.perf_counter() # 打一个时间戳

    train_loss = train_model(epoch, resnet, lossfn, optimizer, trainloader)
    td = time.perf_counter()    # 打一个时间戳
    
    # val_loss, val_correct = val_model(epoch, resnet, lossfn, valloader)
    
    # writer.add_scalar("Loss/train", train_loss, epoch)
    
    # print(f"Epoch {epoch} | TrainLoss {train_loss:.5f} | ValLoss {val_loss:.5f} | ValCorrect {val_correct:.5f}| TrainTime {(td - ts):.5f}s")
    print(f"Epoch {epoch} | TrainLoss {train_loss:.5f} | TrainTime {(td - ts):.5f}s")

    # print(f"Epoch: {epoch} | TrainLoss: {train_loss:.5f} | ValLoss: {test_loss:.5f} | EpochTime: {(td - ts):.5f}s ｜ TimeRemaining: {(time_all / (epoch - start_epoch + 1)) * (start_epoch + num_epochs - epoch - 1):.5f}s")
    print("----- ----- ----- ----- -----")


Epoch 0 | TrainLoss 2.02791 | TrainTime 27.59285s
----- ----- ----- ----- -----
Epoch 1 | TrainLoss 1.72336 | TrainTime 27.70242s
----- ----- ----- ----- -----
Epoch 2 | TrainLoss 1.59175 | TrainTime 27.71325s
----- ----- ----- ----- -----
Epoch 3 | TrainLoss 1.61943 | TrainTime 26.96696s
----- ----- ----- ----- -----
Epoch 4 | TrainLoss 1.39395 | TrainTime 27.33436s
----- ----- ----- ----- -----
