In [1]:
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from torch import nn
from torchvision.datasets import MNIST
from tqdm import tqdm

In [2]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST train / test splits, downloaded into ./datasets when not present.
train_dataset = MNIST('./datasets', train=True, download=True, transform=transform)
val_dataset = MNIST('./datasets', train=False, download=True, transform=transform)

# Seed NumPy's global RNG; every later np.random call in the notebook
# continues from this stream.
np.random.seed(2023)

# Validation subset: 5000 indices drawn uniformly from the test split.
# NOTE(review): randint samples with replacement, so indices may repeat and
# the subset can hold fewer than 5000 distinct images - confirm this is intended.
valset = torch.utils.data.Subset(
    dataset=val_dataset,
    indices=np.random.randint(low=0, high=len(val_dataset), size=5000),
)
val_loader = torch.utils.data.DataLoader(dataset=valset, batch_size=256, shuffle=True)

In [3]:
from mnist_cnn import MNIST_CNN

In [4]:
# Display the number of examples in the MNIST training split.
len(train_dataset)

60000

In [None]:
# Instantiate MNIST_CNN (imported from mnist_cnn) and move its parameters to
# the current CUDA device. This single object is reused for every training
# run in the loop further down, which calls model.reset_parameters(...) first.
model = MNIST_CNN().cuda()

In [None]:
# Training configuration used by the training loop.
epochs = 30                # number of passes over each model's training subset
num_trained_models = 500   # number of models trained and saved to disk
batchsize = 64             # mini-batch size of the training DataLoader
num_data_per_model = 1000  # number of training indices drawn for each model

In [None]:
# Loss applied to the model outputs and the integer class labels in the
# training loop.
criterion = nn.CrossEntropyLoss()

In [None]:
from train_val import mnist_validation

In [None]:
# Train `num_trained_models` models, each on its own random subset of the
# training data, and save the final-epoch weights of each one to
# pretrained_models/model{m}.pt.
#
# Validation-based checkpoint selection (mnist_validation on val_loader) is
# not performed here; only the weights after the last epoch are kept.

# Make sure the output directory exists up front, so the first torch.save
# cannot fail after a full training run.
os.makedirs('pretrained_models', exist_ok=True)

# Send batches to whatever device the model lives on instead of hard-coding
# 'cuda:0', which may differ from the device chosen by model.cuda().
device = next(model.parameters()).device

for m in range(num_trained_models):
    # Draw this model's training indices from NumPy's global RNG.
    # NOTE(review): randint samples with replacement, so a subset may contain
    # repeated indices, and the indices are not stored anywhere - confirm
    # that downstream code does not need them.
    trainset = torch.utils.data.Subset(train_dataset,
                                       np.random.randint(0,
                                                         len(train_dataset),
                                                         num_data_per_model))
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=batchsize,
                                             shuffle=True)

    # The same model object is reused across runs.
    # NOTE(review): presumably this re-initializes the weights from seed 2023
    # so every run starts identically - confirm in mnist_cnn.
    model.reset_parameters(2023)
    model.train()

    # Fresh optimizer per model so no Adam state leaks between runs.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    pbar = tqdm(range(epochs))
    pbar.set_description(f'{m} th model')
    for epoch in pbar:  # loop over the subset multiple times
        counts = 0
        corrects = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Running training accuracy for this epoch; .item() converts to a
            # Python int so no device tensor is kept alive across batches.
            counts += inputs.shape[0]
            corrects += (outputs.argmax(1) == labels).sum().item()

        # Report the epoch's training accuracy on the progress bar.
        if counts:
            pbar.set_postfix({'acc': corrects / counts})

    torch.save(model.state_dict(), f'pretrained_models/model{m}.pt')