In [80]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [81]:
class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(3, 32, (3, 3), (1, 1), (1, 1)) 
    self.conv2 = nn.Conv2d(32, 64, (3, 3), (1, 1), (1, 1))
    self.conv3 = nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1))


    self.fc1 = nn.Linear(128*8*8 ,128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, kernel_size=(2, 2))

    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, kernel_size=(2, 2))

    x = F.relu(self.conv3(x))
    x = F.max_pool2d(x, kernel_size=(2, 2))
    x = torch.flatten(x, start_dim=1)

    x = self.fc1(x)
    x = self.fc2(x)
    output = torch.softmax(x, dim=1)

    return output
      

In [82]:
!pip install wandb



In [83]:
import wandb
wandb.init(project="Persian Mnist torch")

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [84]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model().to(device)
# wandb.watch(model)

In [85]:
batch_size = 64
epochs = 20
config = wandb.config
config.learning_rate = 0.001


In [86]:
from torchvision.transforms.transforms import Resize
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.Resize((70, 70)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

dataset = torchvision.datasets.ImageFolder(root="/content/drive/MyDrive/MNIST_persian", transform = transform)
train_data = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [87]:
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
loss_function = nn.CrossEntropyLoss()


In [88]:

def calc_acc(preds, labels):
  _, pred_max = torch.max(preds, 1)
  acc = torch.sum(pred_max == labels.data, dtype=torch.float64) / len(preds)
  return acc

In [89]:
model.train()

for epoch in range(epochs):
  train_loss = 0.0
  train_acc = 0.0
  for images, labels in tqdm(train_data):

    images, labels = images.to(device), labels.to(device)

    optimizer.zero_grad()
    preds = model(images)

    loss = loss_function(preds, labels)
    loss.backward()

    optimizer.step()

    train_loss += loss
    train_acc += calc_acc(preds, labels)

  total_loss = train_loss / len(train_data)
  total_acc = train_acc / len(train_data)
  print(f"Epoch: {epoch+1}, Loss: {total_loss}, Accuracy: {total_acc}")

  wandb.log({'epochs':  epoch + 1,'loss': total_loss,'acc': total_acc})

100%|██████████| 19/19 [00:03<00:00,  5.04it/s]


Epoch: 1, Loss: 2.304955244064331, Accuracy: 0.1156798245614035


100%|██████████| 19/19 [00:05<00:00,  3.39it/s]


Epoch: 2, Loss: 2.0693411827087402, Accuracy: 0.42790570175438597


100%|██████████| 19/19 [00:05<00:00,  3.73it/s]


Epoch: 3, Loss: 1.8796762228012085, Accuracy: 0.5868969298245613


100%|██████████| 19/19 [00:04<00:00,  4.34it/s]


Epoch: 4, Loss: 1.8454822301864624, Accuracy: 0.6175986842105263


100%|██████████| 19/19 [00:03<00:00,  5.11it/s]


Epoch: 5, Loss: 1.8330790996551514, Accuracy: 0.6285635964912281


100%|██████████| 19/19 [00:03<00:00,  5.71it/s]


Epoch: 6, Loss: 1.7775499820709229, Accuracy: 0.6858552631578947


100%|██████████| 19/19 [00:03<00:00,  5.78it/s]


Epoch: 7, Loss: 1.7504503726959229, Accuracy: 0.709703947368421


100%|██████████| 19/19 [00:03<00:00,  5.81it/s]


Epoch: 8, Loss: 1.7191352844238281, Accuracy: 0.7425986842105263


100%|██████████| 19/19 [00:03<00:00,  5.75it/s]


Epoch: 9, Loss: 1.6712735891342163, Accuracy: 0.7897478070175439


100%|██████████| 19/19 [00:03<00:00,  5.68it/s]


Epoch: 10, Loss: 1.648241400718689, Accuracy: 0.8160635964912281


100%|██████████| 19/19 [00:03<00:00,  5.64it/s]


Epoch: 11, Loss: 1.6360490322113037, Accuracy: 0.828673245614035


100%|██████████| 19/19 [00:03<00:00,  5.73it/s]


Epoch: 12, Loss: 1.6160825490951538, Accuracy: 0.8500548245614036


100%|██████████| 19/19 [00:03<00:00,  5.75it/s]


Epoch: 13, Loss: 1.5875787734985352, Accuracy: 0.8771929824561404


100%|██████████| 19/19 [00:03<00:00,  5.79it/s]


Epoch: 14, Loss: 1.563018798828125, Accuracy: 0.9032346491228069


100%|██████████| 19/19 [00:03<00:00,  5.23it/s]


Epoch: 15, Loss: 1.557346224784851, Accuracy: 0.9073464912280701


100%|██████████| 19/19 [00:03<00:00,  4.86it/s]


Epoch: 16, Loss: 1.5349829196929932, Accuracy: 0.9268092105263157


100%|██████████| 19/19 [00:03<00:00,  5.23it/s]


Epoch: 17, Loss: 1.5166478157043457, Accuracy: 0.9484649122807016


100%|██████████| 19/19 [00:03<00:00,  5.75it/s]


Epoch: 18, Loss: 1.525356650352478, Accuracy: 0.9385964912280701


100%|██████████| 19/19 [00:03<00:00,  5.77it/s]


Epoch: 19, Loss: 1.5113000869750977, Accuracy: 0.9575109649122806


100%|██████████| 19/19 [00:03<00:00,  5.51it/s]

Epoch: 20, Loss: 1.5002306699752808, Accuracy: 0.9654605263157894





In [90]:
torch.save(model.state_dict(), "mnist_p.pth")