In [1]:
# TODO: this file is just for debugging after everything works correct clean the repo
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import SGD

In [2]:
class PiNetSequential(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fc1_1 = nn.Linear(784, 128)
        self.fc1_2 = nn.Linear(784, 128)

        self.fc2_1 = nn.Linear(128, 64)
        self.fc2_2 = nn.Linear(128, 64)

        self.fc3 = nn.Linear(64, 10)

    def forward(self, input_data):
        out_1_1 = self.fc1_1(input_data)
        out_1_2 = self.fc1_2(input_data)
        out_1 = out_1_1 * out_1_2 + out_1_1
        out_2_1 = self.fc2_1(out_1)
        out_2_2 = self.fc2_2(out_1)
        out_2 = out_2_1 * out_2_2 + out_2_1
        out_3 = self.fc3(out_2)
        return out_3

In [3]:
batch_size = 128
num_of_epochs = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

target_transform = transforms.Lambda(lambda y: torch.zeros(10, dtype=torch.float)
                                     .scatter_(0, torch.tensor(y), 1))

# load data
train_dataset = MNIST('../../data', train=True, transform=transform, target_transform=target_transform,
                      download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MNIST('../../data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

last_batch_idx = int(len(train_dataset) / batch_size)
if len(train_dataset) % batch_size != 0:
    last_batch_idx = last_batch_idx + 1

model = PiNetSequential().to(device)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
running_loss = []
running_acc = []
running_curr_loss = []
for epoch in range(num_of_epochs):
    curr_loss = torch.zeros(1).to(device)
    for idx, (data, label) in enumerate(train_loader):
        model.train()
        data = data.reshape(data.size(0), -1).to(device)
        label = label.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(label, out)
        loss.backward()
        optimizer.step()
        curr_loss += loss * 10
        running_curr_loss.append(loss.item() * 10)

        with torch.no_grad():
            if idx == 0 or (idx + 1) % 10 == 0 or (idx + 1) == last_batch_idx:
                curr_acc = 0
                model.eval()
                if idx == 0:
                    running_loss.append(curr_loss.item())
                elif (idx + 1) == last_batch_idx:
                    running_loss.append((curr_loss / ((idx + 1) % 10)).item())
                else:
                    running_loss.append((curr_loss / 10).item())
                test_total = 0
                for test_data, test_label in test_loader:
                    test_data = test_data.reshape(test_data.size(0), -1).to(device)
                    test_label = test_label.to(device)
                    test_out = model(test_data)
                    pred_label = torch.argmax(test_out, dim=1)
                    # noinspection PyTypeChecker
                    curr_acc = curr_acc + torch.count_nonzero(pred_label == test_label)
                    test_total = test_total + test_data.size(0)
                running_acc.append(curr_acc / test_total)
                if idx == 0 or (idx + 1) % 10 == 0:
                    print('epoch: {}, loss: {}, acc: {}'.format(epoch, running_loss[-1], running_acc[-1]))
                curr_loss = torch.zeros(1).to(device)

epoch: 0, loss: 1.6517184972763062, acc: 0.1462000012397766
epoch: 0, loss: 0.8333964347839355, acc: 0.5586000084877014
epoch: 0, loss: 0.6889950037002563, acc: 0.693399965763092
epoch: 0, loss: 0.5761732459068298, acc: 0.7583999633789062
epoch: 0, loss: 0.5271390080451965, acc: 0.7883999943733215
epoch: 0, loss: 0.49322643876075745, acc: 0.8187999725341797
epoch: 0, loss: 0.46784907579421997, acc: 0.8305999636650085
epoch: 0, loss: 0.4393194615840912, acc: 0.8450999855995178
epoch: 0, loss: 0.4344979226589203, acc: 0.8488999605178833
epoch: 0, loss: 0.4274926781654358, acc: 0.8669999837875366
epoch: 0, loss: 0.39989951252937317, acc: 0.8669999837875366
epoch: 0, loss: 0.3819388449192047, acc: 0.8733999729156494
epoch: 0, loss: 0.37089622020721436, acc: 0.8764999508857727
epoch: 0, loss: 0.36728784441947937, acc: 0.8842999935150146
epoch: 0, loss: 0.3459676504135132, acc: 0.8868999481201172
epoch: 0, loss: 0.3498058617115021, acc: 0.8924999833106995
epoch: 0, loss: 0.35417428612709045,