# Imports

In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
import torchvision.transforms as T
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from skimage.color import gray2rgb

import numpy as np
from tqdm import tqdm
from utils.train import train

# Setup

In [2]:
# Training hyper-parameters.
epochs = 5
batch_size = 128
num_classes = 10  # number of output classes for the new classifier head

# Seed the global RNGs so "Restart & Run All" reproduces the same run:
# - numpy's global RNG is what `train_test_split` falls back to when no
#   `random_state` is given;
# - torch's global RNG drives DataLoader shuffling and the initialisation of the
#   freshly created classifier layer.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device}.")

Using cpu.


In [3]:
# Start from the pretrained EfficientNet-B0 weights and replace the Linear layer
# at `model.classifier[1]` with a fresh `num_classes`-way head for fine-tuning.
weights = EfficientNet_B0_Weights.DEFAULT

model = efficientnet_b0(weights=weights)
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)
# Move the model to the target device *before* building the optimizer, as the
# torch.optim docs recommend; nothing else in this notebook moves it explicitly.
model = model.to(device)

# Preprocessing transforms bundled with the pretrained weights.
# NOTE(review): `preprocess` is never applied to the MNIST data in this notebook
# (the dataset transforms are only gray2rgb + ToTensor) — TODO decide whether it
# should be part of the input pipeline or removed.
preprocess = weights.transforms()
opt = Adam(model.parameters(), lr=1e-3)

In [4]:

# MNIST images are single-channel; replicate the grey channel to 3 channels
# (gray2rgb) so the input matches the RGB backbone, then convert to a tensor.
transforms = T.Compose([
    gray2rgb,
    T.ToTensor(),
])

# Root of the local MNIST copy — change this to point at your own dataset cache.
data_dir = "/Tera/datasets/MNIST/"
mnist_train = MNIST(data_dir, download=True, transform=transforms)

# 80/20 split into a "labelled" part used for supervised training and an
# "unlabelled" part held back for evaluation / pseudo-labelling.
# A fixed random_state makes the split reproducible on a fresh kernel.
# NOTE(review): train_test_split indexes the dataset sample by sample, so both
# halves presumably come back as in-memory lists of already-transformed
# (image, label) pairs — confirm memory use is acceptable.
labelled, unlabelled = train_test_split(
    mnist_train,
    test_size=0.2,
    train_size=0.8,
    shuffle=True,
    random_state=0,
)

dataloader = DataLoader(
    labelled,
    batch_size=batch_size,
    shuffle=True,
)

criterion = nn.CrossEntropyLoss()

# Supervised training on the labelled (annotated) split

In [5]:
# Supervised fine-tuning on the labelled split only.
# NOTE(review): the logged loss is negative (e.g. -28252.77 after epoch 0), which
# nn.CrossEntropyLoss cannot produce by itself — TODO check how `utils.train`
# computes/accumulates the reported loss.
# NOTE(review): the logged accuracies (0.7265625, 0.8203125, 0.828125) are exact
# multiples of 1/128 = 1/batch_size, suggesting `train` reports the metric for a
# single batch rather than the whole epoch — confirm.
train(
    model=model,
    epochs=epochs,  # use the configured value rather than a duplicated literal
    criterion=criterion,
    dataloader=dataloader,
    opt=opt,
    metrics=[accuracy_score],
    device=device,
)

100%|██████████| 375/375 [09:50<00:00,  1.58s/it]


[Epoch 0/5] Loss: -28252.76505 accuracy_score: 0.7265625 


100%|██████████| 375/375 [09:51<00:00,  1.58s/it]


[Epoch 1/5] Loss: -115743.53169 accuracy_score: 0.8203125 


100%|██████████| 375/375 [09:48<00:00,  1.57s/it]


[Epoch 2/5] Loss: -250477.25992 accuracy_score: 0.828125 


 11%|█         | 41/375 [01:07<09:08,  1.64s/it]


KeyboardInterrupt: 

In [None]:
# Evaluate the supervised model on the held-out "unlabelled" split. Its ground
# truth labels are still available, so it serves as a test set here.
# No shuffling: order is irrelevant for accuracy, and a fixed order keeps the
# collected predictions aligned with the dataset.
test_loader = DataLoader(
    unlabelled,
    batch_size=batch_size,
    shuffle=False,
)

model.to(device)  # ensure weights live on the same device as the inputs
model.eval()      # inference mode for BatchNorm/Dropout layers

y_true, y_pred = [], []
n_correct, n_seen = 0, 0
# Wrap the *loader* under a new name; `unlabelled` must keep referring to the
# dataset so later cells (pseudo-labelling) can still use it.
progress = tqdm(test_loader)
with torch.no_grad():  # no gradients needed for evaluation
    for x, y in progress:
        logits = model(x.to(device))
        # argmax over the class dimension -> one predicted label per sample;
        # move to CPU before converting to numpy.
        batch_pred = logits.argmax(dim=1).cpu().numpy()
        batch_true = y.numpy()

        y_pred.append(batch_pred)
        y_true.append(batch_true)

        # Running accuracy over every sample seen so far (not just this batch).
        n_correct += int((batch_pred == batch_true).sum())
        n_seen += len(batch_true)
        progress.set_description(f"Accuracy: {n_correct / n_seen * 100:.2f}%")

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average="macro")
print(f"Accuracy: {acc*100:.2f}%  macro-F1: {f1:.4f}")


# Learn using pseudo-labels