In [1]:
import jittor as jt
from jittor import Module
from jittor import nn
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Select the Tk-based matplotlib backend used by the plt.show() calls below.
matplotlib.use('TkAgg')

# Use the GPU only when this Jittor build actually has CUDA available;
# hard-coding 1 makes the notebook fail on CPU-only machines.
jt.flags.use_cuda = jt.has_cuda

In [2]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 64

# Training split, served in shuffled batches.
train_loader = MNIST(train=True, transform=trans.Resize(28))
train_loader.set_attrs(batch_size=batch_size, shuffle=True)

# Validation split, served in a fixed order.
val_loader = MNIST(train=False, transform=trans.Resize(28))
val_loader.set_attrs(batch_size=batch_size, shuffle=False)

In [4]:
# Preview one validation sample: print the batch shapes, then draw image
# `num` of the first batch and print its label.
num = 0
for inputs, targets in val_loader:
    print("inputs.shape: ", inputs.shape)
    print("targets.shape: ", targets.shape)
    # Move the channel axis last so imshow receives an (H, W, C) array.
    plt.imshow(np.transpose(inputs[num].numpy(), (1, 2, 0)))
    print("target: ", targets[num].data[0])
    plt.show()
    # Only the first batch is needed for the preview.
    break

inputs.shape:  [64,3,28,28,]
targets.shape:  [64,]
target:  7


In [5]:
class Model(Module):
    """Small convolutional classifier producing 10 class scores per image.

    Layout: conv(3->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> BatchNorm -> ReLU
    -> 2x2 pooling -> flatten -> linear(64*12*12 -> 256) -> ReLU
    -> linear(256 -> 10).

    The first linear layer is sized for a 64 x 12 x 12 feature map, which is
    what the two 3x3 convolutions and the 2x2 pooling yield for the 3 x 28 x 28
    images this notebook feeds in.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv(3, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv(32, 64, kernel_size=3, stride=1)
        self.bn = nn.BatchNorm(64)
        self.max_pool = nn.Pool(kernel_size=2, stride=2)
        self.relu = nn.Relu()
        # Classifier head.
        self.fc1 = nn.Linear(64 * 12 * 12, 256)
        self.fc2 = nn.Linear(256, 10)

    def execute(self, x):
        """Run the forward pass and return the raw (un-normalised) class scores."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.bn(self.conv2(x)))
        x = self.max_pool(x)
        # Flatten everything except the batch dimension for the linear layers.
        flat = jt.reshape(x, [x.shape[0], -1])
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

In [7]:
# Network instance trained and evaluated by the cells below.
model = Model()

# Cross-entropy between the network's class scores and the integer labels.
loss_func = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over all of the model's parameters.
optimizer = nn.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)

In [8]:
def train(model, train_loader, loss_function, optimizer, epoch):
    """Train `model` for one pass over `train_loader`.

    Args:
        model: network to optimise; switched to training mode first.
        train_loader: iterable of `(inputs, targets)` batches that supports `len()`.
        loss_function: callable `(outputs, targets) -> loss`.
        optimizer: object whose `step(loss)` performs the parameter update.
        epoch: epoch index, used only in the progress message.

    Returns:
        list[float]: the loss value of every batch, in iteration order.
    """
    model.train()
    train_losses = []
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        optimizer.step(loss)

        # Store a plain Python float rather than the loss variable itself:
        # keeping the variable objects holds every batch's loss alive and
        # hands non-numeric objects to the plotting code downstream.
        loss_value = float(loss.data[0])
        train_losses.append(loss_value)

        if batch_idx % 100 == 0:
            num_batches = len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, num_batches, 100. * batch_idx / num_batches, loss_value
            ))
    return train_losses


def test(model, val_loader, loss_function, epoch):
    model.eval()
    total_correct = 0
    total_num = 0
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        outputs = model(inputs)
        pred = np.argmax(outputs.data, axis=1)
        correct = np.sum(targets.data == pred)

        total_correct += correct
        total_num += inputs.shape[0]
    test_acc = total_correct / total_num
    print("Test Accuracy: ", test_acc)
    return test_acc

In [12]:
# Train for a fixed number of epochs, evaluating after each one.
epochs = 5
train_losses = []  # per-batch training losses, concatenated across epochs
test_acc = []      # one validation accuracy per epoch
for epoch in range(epochs):
    epoch_losses = train(model, train_loader, loss_func, optimizer, epoch)
    epoch_acc = test(model, val_loader, loss_func, epoch)
    train_losses.extend(epoch_losses)
    test_acc.append(epoch_acc)

Test Accuracy:  0.9919
Test Accuracy:  0.9916
Test Accuracy:  0.9915
Test Accuracy:  0.9922
Test Accuracy:  0.9924


In [10]:
# Training-loss curve: one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train Loss")
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

  ary = asanyarray(ary)
  ary = asanyarray(ary)


In [13]:
# Validation-accuracy curve: one point per epoch.
fig, ax = plt.subplots()
ax.plot(test_acc, label="Test Accuracy")
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()

In [17]:
# Spot-check the model: show image `num` of the first validation batch,
# then print its label next to the model's prediction.
num = 6
for inputs, targets in val_loader:
    # Move the channel axis last so imshow receives an (H, W, C) array.
    plt.imshow(np.transpose(inputs[num].numpy(), (1, 2, 0)))
    plt.show()

    print("target: ", targets[num].data[0])

    outputs = model(inputs)
    # Predicted class = index of the largest score for each sample.
    pred = outputs.data.argmax(axis=1)
    print("prediction: ", pred[num])

    # Only the first batch is needed.
    break

target:  4
prediction:  4
