In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets

# --- Run configuration ------------------------------------------------------
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
epochs = 70

# --- Preprocessing ----------------------------------------------------------
# Evaluation images only go through ToTensor; training images are first passed
# through TrivialAugmentWide and then through ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
transform_train = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
])

# --- Datasets ---------------------------------------------------------------
# NOTE(review): `root` looks like a placeholder string — point it at a real
# directory before running. Both splits share the same root and are downloaded
# on demand.
cifar10_train = datasets.CIFAR10(
    root='root for CIFAR-10 dataset',
    train=True,
    download=True,
    transform=transform_train,
)
testset = datasets.CIFAR10(
    root='root for CIFAR-10 dataset',
    train=False,
    download=True,
    transform=transform,
)

# --- Loaders ----------------------------------------------------------------
# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = DataLoader(
    dataset=cifar10_train,
    batch_size=128,
    shuffle=True,
    num_workers=8,
)
testloader = DataLoader(
    dataset=testset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

from tqdm import tqdm
import sys

# Number of samples in each split; `val_num` is the denominator used when
# train() turns a correct-prediction count into an accuracy.
tra_num, val_num = len(cifar10_train), len(testset)

def _evaluate(model, loader, desc=None):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is switched to eval mode and gradients are disabled for the
    duration of the pass. Batches are moved to the module-level ``device``.

    Args:
        model: network mapping an image batch to per-class scores.
        loader: iterable of ``(images, labels)`` batches.
        desc: optional prefix shown on the tqdm progress bar.

    Returns:
        Fraction of correctly classified samples as a float (0.0 if the loader
        yields nothing).
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, file=sys.stdout, desc=desc):
            labels = labels.to(device)
            outputs = model(images.to(device))
            # Index of the highest score along the class dimension.
            predicted = torch.max(outputs, dim=1)[1]
            correct += torch.eq(predicted, labels).sum().item()
            seen += labels.size(0)
    # Divide by the samples actually seen so the result does not depend on a
    # separately maintained global count.
    return correct / seen if seen else 0.0


def train(model):
    """Train ``model`` on ``trainloader`` and report accuracy on ``testloader``.

    Runs ``epochs`` epochs of Adam (initial lr 1e-3) with cross-entropy loss and
    a cosine-annealed learning rate that reaches 1e-6 after ``epochs`` steps.
    After every epoch the mean training loss and the test accuracy are printed;
    once training is finished the test accuracy is evaluated and printed again.

    Reads the module-level ``trainloader``, ``testloader``, ``epochs`` and
    ``device``. ``model`` is updated in place and is left in eval mode.

    Args:
        model: the ``nn.Module`` to optimise; expected to already live on
            ``device``.

    Returns:
        None. All results are written to stdout.
    """
    criterion = torch.nn.CrossEntropyLoss()
    lr = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Tie the annealing period to the real epoch count so that changing `epochs`
    # cannot leave the schedule out of sync with the loop. The deprecated
    # `verbose=True` flag is replaced by the explicit learning-rate print below.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    train_steps = len(trainloader)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = model(images.to(device))
            loss = criterion(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            # Read the scalar once and reuse it for both the running sum and
            # the progress-bar text.
            batch_loss = loss.item()
            running_loss += batch_loss
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)

        # The schedule advances once per epoch, not once per batch.
        scheduler.step()
        print('learning rate adjusted to %.4e' % scheduler.get_last_lr()[0])

        val_accurate = _evaluate(model, testloader, desc="valid epoch[{}/{}]".format(epoch + 1, epochs))
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.4f' % (epoch + 1, running_loss / train_steps, val_accurate))

    # print test acc
    # _evaluate() puts the model in eval mode itself, so this is correct even
    # when the loop above ran zero times.
    val_accurate = _evaluate(model, testloader)
    print(val_accurate)

In [None]:
import torchvision
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
# Start from torchvision's ResNet-18 with its default pretrained weights.
model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Swap the stem for a 3x3, stride-1 convolution and drop the max-pool
# (presumably so small CIFAR-10 images are not downsampled early — confirm).
# The new conv layer is freshly initialised, not pretrained.
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
model.maxpool = nn.Identity()

# Fresh 10-way classification head on top of the 512-dimensional features.
model.fc = nn.Linear(in_features=512, out_features=10)

model = model.to(device)
print('model prepared.')

In [None]:
# Run the full training/validation loop defined in train(); per-epoch loss and
# accuracy, plus the final test accuracy, are written to stdout.
train(model)

In [None]:
# Serialize the whole model object (not just its state_dict) to the path below.
# NOTE(review): the path is a placeholder string — replace it with a real file
# path before running.
# NOTE(review): a whole-module pickle can only be loaded where the same class
# definitions are importable; consider saving model.state_dict() instead —
# confirm what downstream loaders expect before changing.
torch.save(model, 'root to save the model')