In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm
import lightning.pytorch as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
from torchmetrics.image.fid import FrechetInceptionDistance


In [2]:
# Both QMNIST splits are read as float tensors through one (stateless) ToTensor
# transform; only the training loader reshuffles each epoch.
_to_tensor = torchvision.transforms.ToTensor()


def _qmnist_split(train):
    """Download (if missing) one QMNIST split and wrap it in a batch-32 DataLoader.

    Returns the ``(dataset, loader)`` pair; the loader shuffles only when ``train`` is true.
    """
    dataset = torchvision.datasets.QMNIST(root='./data', train=train, download=True, transform=_to_tensor)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=train)
    return dataset, loader


train_set, train_loader = _qmnist_split(train=True)
test_set, test_loader = _qmnist_split(train=False)


In [3]:
class Model(nn.Module):
    """Two-layer MLP classifier, sized by default for 28x28 single-channel images (QMNIST).

    Args:
        in_features: number of input values per sample after flattening (default 28 * 28).
        hidden_features: width of the hidden ReLU layer (default 256).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_features=28 * 28, hidden_features=256, num_classes=10):
        super().__init__()
        # Attribute names fc1/fc2 are kept so existing state_dict keys still match.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) class logits of shape (N, num_classes).

        ``x`` is flattened to (N, in_features), so e.g. an (N, 1, 28, 28) image batch
        works; the total element count must be a multiple of ``in_features``.
        ``reshape`` is used instead of ``view`` so non-contiguous inputs (e.g. after a
        transpose or augmentation) are accepted rather than raising.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [26]:
from ema_pytorch import EMA

# Seed torch's global RNG so the weight initialisation below is reproducible on a
# fresh kernel (presumably also the batch order, since train_loader shuffles without
# an explicit generator and the cells run in order).
torch.manual_seed(0)

# The model that the optimizer trains directly.
net = Model()

# Wrap the network in an EMA tracker with the decay (beta) and update schedule below.
ema = EMA(
    net,
    beta = 0.9999,              # exponential moving average factor
    update_after_step = 100,    # only after this number of .update() calls will it start updating
    update_every = 10,          # how often to actually update, to save on compute (updates every 10th .update() call)
)

In [27]:
# Adam over every trainable parameter of `net`; the loss is cross-entropy on raw logits.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(params=net.parameters(), lr=LEARNING_RATE)
criterion = F.cross_entropy

In [30]:
NUM_EPOCHS = 50

# Derive the device from the model itself instead of relying on a `device` variable
# that is only assigned in a later (evaluation) cell — otherwise this cell raises
# NameError on Restart & Run All.
device = next(net.parameters()).device

# Training mode; a no-op for this Linear/ReLU MLP, but keeps the cell correct if it is
# re-run after the evaluation cell has switched the model to eval mode.
net.train()
for epoch in range(NUM_EPOCHS):
    with tqdm.tqdm(train_loader, unit="batch", desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}") as tepoch:
        for data, target in tepoch:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            tepoch.set_postfix_str(f"Loss: {loss.item()}")

            # Feed the freshly stepped weights to the EMA tracker; per its configuration
            # it only actually averages on some of these calls.
            ema.update()
        

100%|██████████| 1875/1875 [00:16<00:00, 114.92batch/s, Loss: 0.00989589560776949]   
100%|██████████| 1875/1875 [00:14<00:00, 132.45batch/s, Loss: 0.03210417553782463]   
100%|██████████| 1875/1875 [00:14<00:00, 132.20batch/s, Loss: 0.068067766726017]     
100%|██████████| 1875/1875 [00:14<00:00, 133.12batch/s, Loss: 0.0073813507333397865] 
100%|██████████| 1875/1875 [00:14<00:00, 133.06batch/s, Loss: 0.0003683572285808623] 
100%|██████████| 1875/1875 [00:14<00:00, 131.57batch/s, Loss: 0.0007042007055133581] 
100%|██████████| 1875/1875 [00:14<00:00, 131.54batch/s, Loss: 0.002568792086094618]  
100%|██████████| 1875/1875 [00:14<00:00, 132.76batch/s, Loss: 0.0014360318891704082] 
100%|██████████| 1875/1875 [00:14<00:00, 131.58batch/s, Loss: 0.00039749627467244864]
100%|██████████| 1875/1875 [00:16<00:00, 110.32batch/s, Loss: 0.00026633558445610106]
100%|██████████| 1875/1875 [00:14<00:00, 130.64batch/s, Loss: 0.004170285537838936]  
100%|██████████| 1875/1875 [00:14<00:00, 131.55batch/s

KeyboardInterrupt: 

In [31]:
# Score on the device the model actually lives on, rather than a hard-coded 'cpu'.
device = next(net.parameters()).device
# Inference mode; a no-op for this Linear/ReLU MLP, but correct for any model.
net.eval()

# NOTE(review): only `net` is scored; the EMA copy built in the setup cell is never
# evaluated here — TODO confirm whether its averaged weights should be reported too.
correct = 0
total = 0
with torch.no_grad():
    with tqdm.tqdm(test_loader, unit="batch") as pbar:
        for data, target in pbar:
            data, target = data.to(device), target.to(device)

            outputs = net(data)
            # Predicted class = index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            pbar.set_postfix_str(f"Accuracy: {correct / total}")


100%|██████████| 1875/1875 [00:11<00:00, 157.40batch/s, Accuracy: 0.9766666666666667]


In [24]:
# Final test-set accuracy accumulated by the evaluation loop.
accuracy = correct / total
print(f"Accuracy: {accuracy}")

Accuracy: 0.9735333333333334
