In [4]:
import os
import cv2
import numpy as np
from tqdm import tqdm

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code - only load this file from a trusted source.
TRAINING_DATA_PATH = "training_data.npy"
training_data = np.load(TRAINING_DATA_PATH, allow_pickle=True)

In [5]:
# Report how many samples were loaded.
n_samples = len(training_data)
print(n_samples)

24946


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small 3-layer CNN mapping 1-channel square images to 2 softmax scores.

    Args:
        input_size: side length in pixels of the square input images. Defaults
            to 50, the shape the rest of this notebook feeds in.
    """

    def __init__(self, input_size=50):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)

        # Push one dummy image through the conv stack to discover how many
        # features reach fc1. This is only a shape probe, so no autograd graph
        # is needed.
        x = torch.randn(input_size, input_size).view(-1, 1, input_size, input_size)
        self._to_linear = None
        with torch.no_grad():
            self.convs(x)

        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 2)

    def convs(self, x):
        """Apply (conv -> ReLU -> 2x2 max-pool) three times and return the maps.

        On the first call (while ``_to_linear`` is still None) this also prints
        the per-sample feature-map shape and records its flattened size.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))

        if self._to_linear is None:
            print(x[0].shape)
            self._to_linear = x[0].numel()
        return x

    def forward(self, x):
        """Return per-class softmax probabilities, one row of 2 per sample."""
        x = self.convs(x)
        x = x.view(-1, self._to_linear)  # flatten each sample for fc1
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)


net = Net()

torch.Size([128, 2, 2])


In [18]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()

# Stack samples into a single ndarray before converting: building a tensor
# directly from a Python list of ndarrays is extremely slow in PyTorch.
X = torch.tensor(np.array([i[0] for i in training_data]), dtype=torch.float32).view(-1, 50, 50)
X = X / 255.0  # assumes 8-bit pixel values; rescales them into [0, 1]
y = torch.tensor(np.array([i[1] for i in training_data]), dtype=torch.float32)

VAL_PCT = 0.1  # fraction of samples held out for validation
val_size = int(len(X) * VAL_PCT)
print(val_size)

2494


In [19]:
# Hold out the last `val_size` samples for validation. An explicit split index
# keeps this correct when val_size == 0 (X[:-0] would be an empty slice).
# NOTE(review): no shuffling happens here - this assumes training_data was
# shuffled when it was built; confirm, or the split may be class-imbalanced.
split_idx = len(X) - val_size
train_X, train_y = X[:split_idx], y[:split_idx]
test_X, test_y = X[split_idx:], y[split_idx:]

print(len(train_X))
print(len(test_y))

22452
2494


In [24]:
BATCH_SIZE = 100
EPOCHS = 1

# Quick training pass: mini-batch updates using the global optimizer and
# loss function; the loss of the final batch is printed afterwards.
for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        end = i + BATCH_SIZE
        batch_y = train_y[i:end]
        batch_X = train_X[i:end].view(-1, 1, 50, 50)

        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()

print(loss)


  0%|          | 0/225 [00:00<?, ?it/s][A
  0%|          | 1/225 [00:01<05:05,  1.37s/it][A
  1%|          | 2/225 [00:01<03:48,  1.02s/it][A
  1%|▏         | 3/225 [00:01<02:53,  1.28it/s][A
  2%|▏         | 4/225 [00:02<02:14,  1.64it/s][A
  2%|▏         | 5/225 [00:02<01:48,  2.03it/s][A
  3%|▎         | 6/225 [00:02<01:29,  2.45it/s][A
  3%|▎         | 7/225 [00:02<01:16,  2.86it/s][A
  4%|▎         | 8/225 [00:02<01:07,  3.23it/s][A
  4%|▍         | 9/225 [00:03<01:00,  3.59it/s][A
  4%|▍         | 10/225 [00:03<00:55,  3.85it/s][A
  5%|▍         | 11/225 [00:03<00:52,  4.05it/s][A
  5%|▌         | 12/225 [00:03<00:50,  4.21it/s][A
  6%|▌         | 13/225 [00:03<00:48,  4.36it/s][A
  6%|▌         | 14/225 [00:04<00:47,  4.44it/s][A
  7%|▋         | 15/225 [00:04<00:46,  4.51it/s][A
  7%|▋         | 16/225 [00:04<00:45,  4.57it/s][A
  8%|▊         | 17/225 [00:04<00:44,  4.66it/s][A
  8%|▊         | 18/225 [00:05<00:44,  4.60it/s][A
  8%|▊         | 19/225 [00:0

tensor(0.2038, grad_fn=<MseLossBackward>)


In [25]:
# Score the network one validation sample at a time: a prediction is correct
# when its argmax matches the argmax of the label.
correct, total = 0, 0
with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        sample = test_X[i].view(-1, 1, 50, 50)
        net_out = net(sample)[0]
        real_class = torch.argmax(test_y[i])
        predicted_class = torch.argmax(net_out)
        correct += int(predicted_class == real_class)
        total += 1

print("Accuracy:", round(correct/total, 3))
    


  0%|          | 0/2494 [00:00<?, ?it/s][A
  0%|          | 1/2494 [00:00<05:08,  8.08it/s][A
  2%|▏         | 57/2494 [00:00<03:32, 11.47it/s][A
  4%|▍         | 110/2494 [00:00<02:26, 16.23it/s][A
  6%|▌         | 155/2494 [00:00<01:42, 22.83it/s][A
  8%|▊         | 209/2494 [00:00<01:11, 32.02it/s][A
 10%|█         | 255/2494 [00:00<00:50, 44.42it/s][A
 12%|█▏        | 309/2494 [00:00<00:35, 61.29it/s][A
 14%|█▍        | 359/2494 [00:00<00:25, 83.15it/s][A
 16%|█▋        | 407/2494 [00:00<00:18, 110.43it/s][A
 18%|█▊        | 459/2494 [00:01<00:14, 144.54it/s][A
 21%|██        | 512/2494 [00:01<00:10, 184.53it/s][A
 23%|██▎       | 562/2494 [00:01<00:08, 222.54it/s][A
 25%|██▍       | 612/2494 [00:01<00:07, 266.98it/s][A
 27%|██▋       | 669/2494 [00:01<00:05, 316.44it/s][A
 29%|██▉       | 720/2494 [00:01<00:05, 348.52it/s][A
 31%|███       | 769/2494 [00:01<00:04, 376.97it/s][A
 33%|███▎      | 820/2494 [00:01<00:04, 408.25it/s][A
 35%|███▌      | 876/2494 [00:0

Accuracy: 0.613





In [28]:
# Inspect the raw softmax output for the last validation sample. Index it
# explicitly instead of relying on a loop variable left over from another cell.
net(test_X[-1].view(-1, 1, 50, 50))

tensor([[0.3738, 0.6262]], grad_fn=<SoftmaxBackward>)

In [29]:
# Check whether PyTorch can see a CUDA-capable GPU.
torch.cuda.is_available()

True

In [30]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [31]:
# Display the currently selected torch.device.
device

device(type='cuda', index=0)

In [33]:
# Pick the compute device: first GPU if PyTorch can see one, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("running on cuda")
else:
    device = torch.device("cpu")
    print("running on cpu")

running on cuda


In [34]:
# Number of CUDA devices visible to PyTorch.
torch.cuda.device_count()

1

In [35]:
# Move the model's parameters to `device`. nn.Module.to returns the module
# itself, so this cell also displays the layer summary.
net.to(device)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=2, bias=True)
)

In [47]:
EPOCHS = 100
BATCH_SIZE = 100


def train(net, epochs=None, batch_size=None):
    """Train `net` in place on the global `train_X` / `train_y`.

    A fresh Adam optimizer (lr=0.001) and MSE loss are created on every call,
    so optimizer state does not carry over between calls. Each mini-batch is
    reshaped to (-1, 1, 50, 50) and moved to the global `device`.

    Args:
        net: the model to train; its parameters are updated in place.
        epochs: passes over the training data; defaults to the global
            ``EPOCHS`` at call time.
        batch_size: samples per mini-batch; defaults to the global
            ``BATCH_SIZE`` at call time.
    """
    if epochs is None:
        epochs = EPOCHS
    if batch_size is None:
        batch_size = BATCH_SIZE

    optimizer = optim.Adam(net.parameters(), lr=0.001)
    loss_function = nn.MSELoss()
    for _ in range(epochs):
        for i in tqdm(range(0, len(train_X), batch_size)):
            batch_X = train_X[i:i + batch_size].view(-1, 1, 50, 50).to(device)
            batch_y = train_y[i:i + batch_size].to(device)

            net.zero_grad()
            outputs = net(batch_X)
            loss = loss_function(outputs, batch_y)
            loss.backward()
            optimizer.step()


train(net)
        


  0%|          | 0/225 [00:00<?, ?it/s][A
  3%|▎         | 7/225 [00:00<00:03, 66.21it/s][A
  7%|▋         | 15/225 [00:00<00:03, 68.43it/s][A
 10%|█         | 23/225 [00:00<00:02, 70.27it/s][A
 14%|█▍        | 32/225 [00:00<00:02, 74.16it/s][A
 18%|█▊        | 41/225 [00:00<00:02, 76.55it/s][A
 22%|██▏       | 50/225 [00:00<00:02, 79.42it/s][A
 26%|██▌       | 59/225 [00:00<00:02, 80.91it/s][A
 30%|███       | 68/225 [00:00<00:01, 81.88it/s][A
 34%|███▍      | 77/225 [00:00<00:01, 82.61it/s][A
 38%|███▊      | 86/225 [00:01<00:01, 82.89it/s][A
 42%|████▏     | 95/225 [00:01<00:01, 82.63it/s][A
 46%|████▌     | 104/225 [00:01<00:01, 82.91it/s][A
 50%|█████     | 113/225 [00:01<00:01, 83.10it/s][A
 54%|█████▍    | 122/225 [00:01<00:01, 84.17it/s][A
 58%|█████▊    | 131/225 [00:01<00:01, 84.69it/s][A
 62%|██████▏   | 140/225 [00:01<00:00, 85.06it/s][A
 66%|██████▌   | 149/225 [00:01<00:00, 84.13it/s][A
 70%|███████   | 158/225 [00:01<00:00, 85.15it/s][A
 74%|███████▍ 

 32%|███▏      | 72/225 [00:00<00:01, 86.52it/s][A
 36%|███▌      | 81/225 [00:00<00:01, 86.84it/s][A
 40%|████      | 90/225 [00:01<00:01, 87.32it/s][A
 44%|████▍     | 99/225 [00:01<00:01, 87.65it/s][A
 48%|████▊     | 108/225 [00:01<00:01, 88.00it/s][A
 52%|█████▏    | 117/225 [00:01<00:01, 87.48it/s][A
 56%|█████▌    | 126/225 [00:01<00:01, 87.12it/s][A
 60%|██████    | 135/225 [00:01<00:01, 87.38it/s][A
 64%|██████▍   | 144/225 [00:01<00:00, 87.19it/s][A
 68%|██████▊   | 153/225 [00:01<00:00, 86.58it/s][A
 72%|███████▏  | 162/225 [00:01<00:00, 87.38it/s][A
 76%|███████▌  | 171/225 [00:01<00:00, 87.70it/s][A
 80%|████████  | 180/225 [00:02<00:00, 87.68it/s][A
 84%|████████▍ | 189/225 [00:02<00:00, 87.15it/s][A
 88%|████████▊ | 198/225 [00:02<00:00, 87.54it/s][A
 92%|█████████▏| 207/225 [00:02<00:00, 87.56it/s][A
 96%|█████████▌| 216/225 [00:02<00:00, 86.56it/s][A
100%|██████████| 225/225 [00:02<00:00, 86.90it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍

 52%|█████▏    | 118/225 [00:01<00:01, 84.10it/s][A
 56%|█████▋    | 127/225 [00:01<00:01, 83.92it/s][A
 60%|██████    | 136/225 [00:01<00:01, 85.34it/s][A
 64%|██████▍   | 145/225 [00:01<00:00, 85.28it/s][A
 68%|██████▊   | 154/225 [00:01<00:00, 84.51it/s][A
 72%|███████▏  | 163/225 [00:01<00:00, 84.23it/s][A
 76%|███████▋  | 172/225 [00:02<00:00, 84.26it/s][A
 80%|████████  | 181/225 [00:02<00:00, 83.58it/s][A
 84%|████████▍ | 190/225 [00:02<00:00, 83.57it/s][A
 88%|████████▊ | 199/225 [00:02<00:00, 83.80it/s][A
 92%|█████████▏| 208/225 [00:02<00:00, 83.96it/s][A
100%|██████████| 225/225 [00:02<00:00, 84.31it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍         | 10/225 [00:00<00:02, 90.55it/s][A
  8%|▊         | 19/225 [00:00<00:02, 89.63it/s][A
 12%|█▏        | 28/225 [00:00<00:02, 88.49it/s][A
 16%|█▋        | 37/225 [00:00<00:02, 87.82it/s][A
 20%|██        | 46/225 [00:00<00:02, 86.12it/s][A
 24%|██▍       | 55/225 [00:00<00:02, 84.61it/s][A
 28%|██▊

 72%|███████▏  | 162/225 [00:01<00:00, 84.06it/s][A
 76%|███████▌  | 171/225 [00:02<00:00, 83.44it/s][A
 80%|████████  | 180/225 [00:02<00:00, 84.28it/s][A
 84%|████████▍ | 189/225 [00:02<00:00, 85.00it/s][A
 88%|████████▊ | 198/225 [00:02<00:00, 86.01it/s][A
 92%|█████████▏| 207/225 [00:02<00:00, 85.74it/s][A
 96%|█████████▌| 216/225 [00:02<00:00, 86.05it/s][A
100%|██████████| 225/225 [00:02<00:00, 84.75it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍         | 9/225 [00:00<00:02, 87.10it/s][A
  8%|▊         | 18/225 [00:00<00:02, 86.11it/s][A
 12%|█▏        | 27/225 [00:00<00:02, 85.45it/s][A
 16%|█▌        | 36/225 [00:00<00:02, 85.48it/s][A
 20%|██        | 45/225 [00:00<00:02, 84.19it/s][A
 24%|██▍       | 54/225 [00:00<00:02, 83.76it/s][A
 28%|██▊       | 63/225 [00:00<00:01, 84.17it/s][A
 32%|███▏      | 72/225 [00:00<00:01, 85.42it/s][A
 36%|███▌      | 81/225 [00:00<00:01, 84.61it/s][A
 40%|████      | 90/225 [00:01<00:01, 84.52it/s][A
 44%|████▍   

 92%|█████████▏| 207/225 [00:02<00:00, 86.19it/s][A
 96%|█████████▌| 216/225 [00:02<00:00, 86.61it/s][A
100%|██████████| 225/225 [00:02<00:00, 86.23it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍         | 10/225 [00:00<00:02, 90.65it/s][A
  8%|▊         | 19/225 [00:00<00:02, 88.94it/s][A
 12%|█▏        | 28/225 [00:00<00:02, 87.26it/s][A
 16%|█▋        | 37/225 [00:00<00:02, 86.60it/s][A
 20%|██        | 46/225 [00:00<00:02, 85.34it/s][A
 24%|██▍       | 55/225 [00:00<00:02, 84.18it/s][A
 28%|██▊       | 64/225 [00:00<00:01, 83.63it/s][A
 32%|███▏      | 73/225 [00:00<00:01, 83.94it/s][A
 36%|███▋      | 82/225 [00:00<00:01, 85.01it/s][A
 40%|████      | 91/225 [00:01<00:01, 85.29it/s][A
 44%|████▍     | 100/225 [00:01<00:01, 85.97it/s][A
 48%|████▊     | 109/225 [00:01<00:01, 85.96it/s][A
 52%|█████▏    | 118/225 [00:01<00:01, 86.20it/s][A
 56%|█████▋    | 127/225 [00:01<00:01, 86.62it/s][A
 60%|██████    | 136/225 [00:01<00:01, 86.42it/s][A
 64%|██████▍

 96%|█████████▌| 216/225 [00:02<00:00, 85.64it/s][A
100%|██████████| 225/225 [00:02<00:00, 85.84it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍         | 9/225 [00:00<00:02, 88.44it/s][A
  8%|▊         | 18/225 [00:00<00:02, 86.92it/s][A
 12%|█▏        | 26/225 [00:00<00:02, 84.52it/s][A
 16%|█▌        | 35/225 [00:00<00:02, 84.47it/s][A
 20%|█▉        | 44/225 [00:00<00:02, 84.18it/s][A
 24%|██▎       | 53/225 [00:00<00:02, 84.23it/s][A
 28%|██▊       | 62/225 [00:00<00:01, 83.78it/s][A
 32%|███▏      | 71/225 [00:00<00:01, 83.34it/s][A
 36%|███▌      | 80/225 [00:00<00:01, 83.51it/s][A
 40%|███▉      | 89/225 [00:01<00:01, 82.47it/s][A
 44%|████▎     | 98/225 [00:01<00:01, 82.57it/s][A
 48%|████▊     | 107/225 [00:01<00:01, 82.63it/s][A
 52%|█████▏    | 116/225 [00:01<00:01, 83.84it/s][A
 56%|█████▌    | 125/225 [00:01<00:01, 84.45it/s][A
 60%|█████▉    | 134/225 [00:01<00:01, 85.13it/s][A
 64%|██████▎   | 143/225 [00:01<00:00, 84.89it/s][A
 68%|██████▊  

 12%|█▏        | 27/225 [00:00<00:02, 86.85it/s][A
 16%|█▌        | 36/225 [00:00<00:02, 86.95it/s][A
 20%|██        | 45/225 [00:00<00:02, 86.02it/s][A
 24%|██▍       | 54/225 [00:00<00:02, 85.02it/s][A
 28%|██▊       | 63/225 [00:00<00:01, 84.57it/s][A
 32%|███▏      | 72/225 [00:00<00:01, 84.03it/s][A
 36%|███▌      | 81/225 [00:00<00:01, 83.19it/s][A
 40%|████      | 90/225 [00:01<00:01, 83.07it/s][A
 44%|████▍     | 99/225 [00:01<00:01, 82.99it/s][A
 48%|████▊     | 108/225 [00:01<00:01, 83.85it/s][A
 52%|█████▏    | 117/225 [00:01<00:01, 83.88it/s][A
 56%|█████▌    | 126/225 [00:01<00:01, 83.79it/s][A
 60%|██████    | 135/225 [00:01<00:01, 83.71it/s][A
 64%|██████▍   | 144/225 [00:01<00:00, 83.66it/s][A
 68%|██████▊   | 153/225 [00:01<00:00, 83.17it/s][A
 72%|███████▏  | 162/225 [00:01<00:00, 82.47it/s][A
 76%|███████▌  | 171/225 [00:02<00:00, 83.25it/s][A
 80%|████████  | 180/225 [00:02<00:00, 83.81it/s][A
 84%|████████▍ | 189/225 [00:02<00:00, 83.96it/s][A
 8

 20%|██        | 45/225 [00:00<00:02, 85.94it/s][A
 24%|██▍       | 54/225 [00:00<00:01, 85.94it/s][A
 28%|██▊       | 63/225 [00:00<00:01, 85.94it/s][A
 32%|███▏      | 72/225 [00:00<00:01, 85.94it/s][A
 36%|███▌      | 81/225 [00:00<00:01, 84.74it/s][A
 40%|████      | 90/225 [00:01<00:01, 84.62it/s][A
 44%|████▍     | 99/225 [00:01<00:01, 85.01it/s][A
 48%|████▊     | 108/225 [00:01<00:01, 84.56it/s][A
 52%|█████▏    | 117/225 [00:01<00:01, 84.01it/s][A
 56%|█████▌    | 126/225 [00:01<00:01, 84.82it/s][A
 60%|██████    | 135/225 [00:01<00:01, 84.58it/s][A
 64%|██████▍   | 144/225 [00:01<00:00, 84.86it/s][A
 68%|██████▊   | 153/225 [00:01<00:00, 84.83it/s][A
 72%|███████▏  | 162/225 [00:01<00:00, 84.58it/s][A
 76%|███████▌  | 171/225 [00:02<00:00, 84.76it/s][A
 80%|████████  | 180/225 [00:02<00:00, 84.16it/s][A
 84%|████████▍ | 189/225 [00:02<00:00, 82.59it/s][A
 88%|████████▊ | 198/225 [00:02<00:00, 81.74it/s][A
 92%|█████████▏| 207/225 [00:02<00:00, 82.51it/s][A


 32%|███▏      | 72/225 [00:00<00:01, 86.21it/s][A
 36%|███▌      | 81/225 [00:00<00:01, 86.13it/s][A
 40%|████      | 90/225 [00:01<00:01, 85.58it/s][A
 44%|████▍     | 99/225 [00:01<00:01, 85.45it/s][A
 48%|████▊     | 108/225 [00:01<00:01, 85.11it/s][A
 52%|█████▏    | 117/225 [00:01<00:01, 84.88it/s][A
 56%|█████▌    | 126/225 [00:01<00:01, 84.48it/s][A
 60%|██████    | 135/225 [00:01<00:01, 84.19it/s][A
 64%|██████▍   | 144/225 [00:01<00:00, 84.00it/s][A
 68%|██████▊   | 153/225 [00:01<00:00, 84.10it/s][A
 72%|███████▏  | 162/225 [00:01<00:00, 83.29it/s][A
 76%|███████▌  | 171/225 [00:02<00:00, 83.01it/s][A
 80%|████████  | 180/225 [00:02<00:00, 82.82it/s][A
 84%|████████▍ | 189/225 [00:02<00:00, 83.97it/s][A
 88%|████████▊ | 198/225 [00:02<00:00, 83.84it/s][A
 92%|█████████▏| 207/225 [00:02<00:00, 84.46it/s][A
 96%|█████████▌| 216/225 [00:02<00:00, 83.95it/s][A
100%|██████████| 225/225 [00:02<00:00, 84.66it/s][A

  0%|          | 0/225 [00:00<?, ?it/s][A
  4%|▍

In [48]:
def test(net):
    """Print the accuracy of `net` on the global validation set.

    Samples are evaluated one at a time on the global `device`. A prediction
    counts as correct when the argmax of the network output equals the argmax
    of the label. Reads the globals `test_X`, `test_y` and `device`.

    NOTE(review): a later cell redefines `test` as `test(size=32)`, which
    shadows this function once that cell runs.
    """
    if len(test_X) == 0:
        # Avoid ZeroDivisionError when no validation samples were held out.
        print("Accuracy: n/a (empty test set)")
        return

    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(len(test_X))):
            real_class = torch.argmax(test_y[i]).to(device)
            net_out = net(test_X[i].view(-1, 1, 50, 50).to(device))[0]

            predicted_class = torch.argmax(net_out)
            if predicted_class == real_class:
                correct += 1
            total += 1

    print("Accuracy:", round(correct / total, 3))


test(net)


  0%|          | 0/2494 [00:00<?, ?it/s][A
  2%|▏         | 53/2494 [00:00<00:04, 520.82it/s][A
  5%|▍         | 121/2494 [00:00<00:04, 559.44it/s][A
  7%|▋         | 187/2494 [00:00<00:03, 585.10it/s][A
 10%|█         | 256/2494 [00:00<00:03, 610.21it/s][A
 13%|█▎        | 322/2494 [00:00<00:03, 623.02it/s][A
 15%|█▌        | 377/2494 [00:00<00:03, 568.16it/s][A
 18%|█▊        | 444/2494 [00:00<00:03, 594.16it/s][A
 20%|██        | 505/2494 [00:00<00:03, 597.53it/s][A
 23%|██▎       | 572/2494 [00:00<00:03, 616.33it/s][A
 26%|██▌       | 639/2494 [00:01<00:02, 630.26it/s][A
 28%|██▊       | 702/2494 [00:01<00:03, 583.45it/s][A
 31%|███       | 769/2494 [00:01<00:02, 605.77it/s][A
 33%|███▎      | 830/2494 [00:01<00:03, 539.95it/s][A
 36%|███▌      | 896/2494 [00:01<00:02, 569.22it/s][A
 39%|███▊      | 964/2494 [00:01<00:02, 597.26it/s][A
 41%|████▏     | 1031/2494 [00:01<00:02, 616.06it/s][A
 44%|████▍     | 1094/2494 [00:01<00:02, 609.87it/s][A
 47%|████▋     | 11

Accuracy: 0.714





In [None]:
def fwd_pass(X, y, train=False):
    """Run the global `net` on one batch and score it.

    Uses the globals `net`, `loss_function` and `optimizer`. Gradients are not
    disabled here, so wrap evaluation-only calls in ``torch.no_grad()``.

    Args:
        X: input batch for `net`.
        y: target batch with one row per sample; each row is compared with the
            matching output row by argmax (presumably one-hot labels - confirm).
        train: when True, zero the gradients first and take one optimizer step
            on the loss.

    Returns:
        (acc, loss): the fraction of rows whose predicted argmax matches the
        target argmax, and the loss tensor from `loss_function`.
    """
    if train:
        net.zero_grad()
    outputs = net(X)
    matches = outputs.argmax(dim=1) == y.argmax(dim=1)
    acc = matches.sum().item() / len(matches)
    loss = loss_function(outputs, y)

    if train:
        loss.backward()
        optimizer.step()
    return acc, loss
    
        

In [None]:
def test(size = 32):
    X, y = test_X[:size], test_y[:size]
    