In [4]:
import torch
import timm
import tqdm
import sys
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from datetime import datetime 
import numpy as np

"""
CIFAR10:
- 32x32 colour image
- 60000 training
- 10000 testing
- ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] classes
"""

#手动定义参数类，专门存储训练的参数，
class cfg(object):
  pass

class args(object):
  pass

args.epochs = 40
args.batch_size = 32
args.model = 'resnet50'
args.learning_rate = 0.01
args.path = '/Users/lixu/Desktop/llm-notes/neural_network/data'
cfg.valid_loss = np.array([])

model = timm.create_model(args.model, pretrained=True,num_classes=10)
device = torch.device("cpu")
model.to(device)
train_dataset = torchvision.datasets.CIFAR10(root = args.path,
                                              train = True,
                                                  transform = transforms.Compose([
#                                                           transforms.Resize((32,32)),
                                                          transforms.ToTensor(),
                                                          transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                                  download = True)

# train_dataset, val_dataset = torch.utils.data.random_split(row_dataset, [40000, 10000])


test_dataset = torchvision.datasets.CIFAR10(root = args.path,
                                                  train = False,
                                                  transform = transforms.Compose([
#                                                           transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1325,), std = (0.3105,))]),)


train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = args.batch_size,
                                           shuffle = True)


test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                          batch_size = args.batch_size,
                                          shuffle = True)


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/lixu/Desktop/llm-notes/neural_network/data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:08<00:00, 20547048.25it/s]


Extracting /Users/lixu/Desktop/llm-notes/neural_network/data/cifar-10-python.tar.gz to /Users/lixu/Desktop/llm-notes/neural_network/data


In [None]:
# plot the validation loss
cfg.valid_loss = np.array([])
def display_loss(cfg, data):
    if len(cfg.valid_loss) == 0:
        cfg.valid_loss = np.array([data])
        x, y = cfg.valid_loss.T
        plt.scatter(x,y)
        plt.show()
    else:
        cfg.valid_loss = np.concatenate([cfg.valid_loss, np.array([data])])
        x, y = cfg.valid_loss.T
        plt.plot(x,y)
        plt.show()


In [None]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),args.learning_rate)


# 0: loading existing model parameters
# model.load_state_dict(torch.load(args.path)
# model.eval()


print('start training')
total_step = len(train_loader)
avg_vloss = 0
best_vloss = 0
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
epoch_number = 0
best_vloss = 1_000_000.

for epoch in range(args.epochs):
    
    # 1. Per epoch
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)
        labels = labels.to(device)
        model.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, args.epochs, i+1, total_step, loss.item()))
    print('end one traning epoch')

    
    # 2. validation: Disable gradient computation and reduce memory consumption.
    avg_vloss = 0.0
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            avg_vloss += loss.item()
    print('Training vs. Validation Loss',
            { 'Training' : best_vloss, 'Validation' : avg_vloss/len(test_loader) },
            epoch_number + 1)
    
    display_loss(cfg, [epoch_number, avg_vloss/len(test_loader)])

#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
        
        
    # 3. Track best performance, and save the model's state
    avg_vloss = avg_vloss /len(test_loader) / (epoch_number + 1)
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = args.path + '/model_{}_{}'.format(timestamp, epoch_number)
        print("store at {}".format(model_path))
        torch.save(model.state_dict(), model_path)
        
    
    # 4. epoch loop, ref: https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
    epoch_number += 1

print('end training')

In [1]:
from sklearn import metrics

# These are just example labels
y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 0, 1, 1, 1, 1]

# Accuracy
accuracy = metrics.accuracy_score(y_true, y_pred)
print(f'Accuracy: {accuracy}')

# Precision
precision = metrics.precision_score(y_true, y_pred)
print(f'Precision: {precision}')

# Recall
recall = metrics.recall_score(y_true, y_pred)
print(f'Recall: {recall}')

# F1 Score
f1 = metrics.f1_score(y_true, y_pred)
print(f'F1 Score: {f1}')

Accuracy: 0.6666666666666666
Precision: 0.75
Recall: 0.75
F1 Score: 0.75


In [5]:
with torch.no_grad():
    correct = 0
    total = 0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
    
        print(outputs[0])
        print(labels)
        break
        

tensor([-0.0811, -0.0645,  0.0974, -0.0639,  0.3708, -0.1186, -0.0487, -0.3961,
         0.0347,  0.0164])
tensor([5, 4, 8, 8, 7, 6, 1, 6, 7, 5, 5, 8, 8, 9, 7, 5, 3, 9, 7, 4, 2, 0, 8, 0,
        9, 0, 6, 3, 7, 2, 8, 7])
