In [1]:
import os
import cv2
import numpy as np
from tqdm import tqdm

REBUILD_DATA = True


class LikesVSDislikes():
    IMG_SIZE = 50
    LIKES = "../images/like_images"
    DISLIKES = "../images/dislike_images"
    LABELS = {LIKES: 0, DISLIKES: 1}

    training_data = []

    likecount = 0
    dislikecount = 0

    def make_training_data(self):
        for label in self.LABELS:
            print(label)
            for f in tqdm(os.listdir(label)):
                try:
                    path = os.path.join(label, f)
                    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                    img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
                    self.training_data.append([np.array(img), np.eye(2)[self.LABELS[label]]])

                    if label == self.LIKES:
                        self.likecount += 1
                    elif label == self.DISLIKES:
                        self.dislikecount += 1
                except Exception as e:
                    print(str(e))
                    pass
        np.random.shuffle(self.training_data)
        np.save("training_data.npy", self.training_data)
        print("Likes:", self.likecount)
        print("Dislikes:", self.dislikecount)


if REBUILD_DATA:
    likevsdislike = LikesVSDislikes()
    likevsdislike.make_training_data()

training_data = np.load("training_data.npy", allow_pickle=True)

  1%|          | 8/662 [00:00<00:08, 74.13it/s]

../images/like_images


100%|██████████| 662/662 [00:07<00:00, 87.89it/s]
  2%|▏         | 9/590 [00:00<00:06, 83.99it/s]

../images/dislike_images


100%|██████████| 590/590 [00:06<00:00, 88.32it/s]
  return array(a, dtype, copy=False, order=order, subok=True)


Likes: 662
Dislikes: 590


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        
        x = torch.randn(50,50).view(-1,1,50,50)
        self._to_linear = None
        self.convs(x)
        
        self.fc1 = nn.Linear(self._to_linear,512)
        self.fc2 = nn.Linear(512, 2)
    
        
    def convs(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv3(x)), (2,2))
        
        print(x[0].shape)
        
        if self._to_linear is None:
            self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
        return x
            
    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self._to_linear)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
    
net = Net()

torch.Size([128, 2, 2])


In [3]:
import torch.optim as optim

optimizer = optim.Adam(net.parameters(), lr = .001)
loss_function = nn.MSELoss()

In [4]:
X = torch.Tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])

In [5]:
VAL_PCT = .25
val_size = int(len(X)*VAL_PCT)
print(val_size)

313


In [6]:
train_X = X[:-val_size]
train_y = y[:-val_size]

test_X = X[-val_size:]
test_y = y[-val_size:]

print(len(train_X))
print(len(test_X))

939
313


In [7]:
BATCH_SIZE = 10
EPOCHS = 3 

for epoch in range(EPOCHS):
    for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
        batch_X = train_X[i:i+BATCH_SIZE].view(-1,1,50,50)
        batch_y = train_y[i:i+BATCH_SIZE]
        
        net.zero_grad()
        outputs = net(batch_X)
        loss = loss_function(outputs, batch_y)
        loss.backward()
        optimizer.step()
print(loss)

  2%|▏         | 2/94 [00:00<00:04, 19.88it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  7%|▋         | 7/94 [00:00<00:04, 20.33it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 13%|█▎        | 12/94 [00:00<00:04, 20.27it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 19%|█▉        | 18/94 [00:00<00:03, 20.60it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 22%|██▏       | 21/94 [00:01<00:03, 20.61it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 26%|██▌       | 24/94 [00:01<00:03, 20.60it/s]

torch.Size([128, 2, 2])


 28%|██▊       | 26/94 [00:01<00:03, 20.30it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 31%|███       | 29/94 [00:01<00:03, 20.61it/s]

torch.Size([128, 2, 2])


 34%|███▍      | 32/94 [00:01<00:02, 20.68it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 40%|████      | 38/94 [00:01<00:02, 20.67it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 44%|████▎     | 41/94 [00:01<00:02, 20.92it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 47%|████▋     | 44/94 [00:02<00:02, 20.80it/s]

torch.Size([128, 2, 2])


 50%|█████     | 47/94 [00:02<00:02, 20.76it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 56%|█████▋    | 53/94 [00:02<00:01, 21.07it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 60%|█████▉    | 56/94 [00:02<00:01, 21.36it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 63%|██████▎   | 59/94 [00:02<00:01, 20.99it/s]

torch.Size([128, 2, 2])


 66%|██████▌   | 62/94 [00:02<00:01, 21.14it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 72%|███████▏  | 68/94 [00:03<00:01, 21.07it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 76%|███████▌  | 71/94 [00:03<00:01, 20.90it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 79%|███████▊  | 74/94 [00:03<00:00, 21.29it/s]

torch.Size([128, 2, 2])


 82%|████████▏ | 77/94 [00:03<00:00, 21.32it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 88%|████████▊ | 83/94 [00:03<00:00, 22.14it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 95%|█████████▍| 89/94 [00:04<00:00, 23.01it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


100%|██████████| 94/94 [00:04<00:00, 21.38it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  6%|▋         | 6/94 [00:00<00:03, 25.55it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 13%|█▎        | 12/94 [00:00<00:03, 26.16it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 19%|█▉        | 18/94 [00:00<00:03, 24.57it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 22%|██▏       | 21/94 [00:00<00:03, 23.31it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 29%|██▊       | 27/94 [00:01<00:03, 21.67it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 35%|███▌      | 33/94 [00:01<00:02, 23.71it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 41%|████▏     | 39/94 [00:01<00:02, 24.06it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 45%|████▍     | 42/94 [00:01<00:02, 24.70it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 51%|█████     | 48/94 [00:01<00:01, 24.64it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 61%|██████    | 57/94 [00:02<00:01, 26.46it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 64%|██████▍   | 60/94 [00:02<00:01, 26.70it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 70%|███████   | 66/94 [00:02<00:01, 25.48it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 77%|███████▋  | 72/94 [00:02<00:00, 26.22it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 83%|████████▎ | 78/94 [00:03<00:00, 26.57it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 89%|████████▉ | 84/94 [00:03<00:00, 26.89it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 99%|█████████▉| 93/94 [00:03<00:00, 27.46it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


100%|██████████| 94/94 [00:03<00:00, 25.51it/s]
  3%|▎         | 3/94 [00:00<00:03, 27.10it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 10%|▉         | 9/94 [00:00<00:03, 27.24it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 16%|█▌        | 15/94 [00:00<00:02, 27.38it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 22%|██▏       | 21/94 [00:00<00:02, 25.35it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 26%|██▌       | 24/94 [00:00<00:02, 23.46it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 32%|███▏      | 30/94 [00:01<00:02, 22.02it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 38%|███▊      | 36/94 [00:01<00:02, 22.51it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 41%|████▏     | 39/94 [00:01<00:02, 21.86it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 48%|████▊     | 45/94 [00:01<00:02, 21.31it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 54%|█████▍    | 51/94 [00:02<00:02, 21.39it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 57%|█████▋    | 54/94 [00:02<00:01, 21.20it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 64%|██████▍   | 60/94 [00:02<00:01, 21.73it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 70%|███████   | 66/94 [00:02<00:01, 20.81it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 73%|███████▎  | 69/94 [00:03<00:01, 20.74it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 80%|███████▉  | 75/94 [00:03<00:00, 21.05it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 86%|████████▌ | 81/94 [00:03<00:00, 21.03it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 89%|████████▉ | 84/94 [00:03<00:00, 21.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 96%|█████████▌| 90/94 [00:04<00:00, 21.65it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


100%|██████████| 94/94 [00:04<00:00, 22.26it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
tensor(0.2429, grad_fn=<MseLossBackward>)





In [8]:
correct = 0
total = 0

with torch.no_grad():
    for i in tqdm(range(len(test_X))):
        real_class = torch.argmax(test_y[i])
        net_out = net(test_X[i].view(-1,1,50,50))[0]
        
        print(real_class, net_out)
        
        predicted_class = torch.argmax(net_out)
        if predicted_class == real_class:
            correct += 1
        total += 1

print('Accuracy:', round(correct/total,3))

  8%|▊         | 24/313 [00:00<00:01, 238.48it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771

 23%|██▎       | 73/313 [00:00<00:01, 238.39it/s]

tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(

 40%|███▉      | 125/313 [00:00<00:00, 246.88it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771

 58%|█████▊    | 183/313 [00:00<00:00, 263.73it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771

 76%|███████▋  | 239/313 [00:00<00:00, 269.56it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771

 85%|████████▍ | 266/313 [00:01<00:00, 265.97it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771

 94%|█████████▎| 293/313 [00:01<00:00, 258.21it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(0) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771

100%|██████████| 313/313 [00:01<00:00, 256.52it/s]

torch.Size([128, 2, 2])
tensor(1) tensor([0.5229, 0.4771])
Accuracy: 0.537





In [9]:
def fwd_pass(X,y, train=False):
    if train:
        net.zero_grad()
    outputs = net(X)
    matches = [torch.argmax(i) == torch.argmax(j) for i,j in zip(outputs, y)]
    acc = matches.count(True)/len(matches)
    loss = loss_function(outputs, y)
    
    if train:
        loss.backward()
        optimizer.step()
    return acc, loss

In [10]:
import numpy as np
def test(size=32):
    random_start = np.random.randint(len(test_X)-size)
    X,y = test_X[random_start:random_start+size], test_y[random_start:random_start+size]
    with torch.no_grad():
        val_acc, val_loss = fwd_pass(X.view(-1,1,50,50),y)
    return val_acc, val_loss

val_acc, val_loss = test(size=32)
print(val_acc,val_loss)

torch.Size([128, 2, 2])
0.4375 tensor(0.2534)


In [11]:
import time

MODEL_NAME = f"model-{int(time.time())}"

net = Net()
optimizer = optim.Adam(net.parameters(), lr = .001)
loss_function = nn.MSELoss()

print(MODEL_NAME)

torch.Size([128, 2, 2])
model-1598618390


In [12]:
def train():
    BATCH_SIZE = 10
    EPOCHS = 1
    with open("model.log", "a") as f:
        for epoch in range(EPOCHS):
            for i in tqdm(range(0,len(train_X), BATCH_SIZE)):
                batch_X = train_X[i:i+BATCH_SIZE].view(-1,1,50,50)
                batch_y = train_y[i:i+BATCH_SIZE]
                
                acc, loss = fwd_pass(batch_X, batch_y, train=True)
                if i & 50 == 0:
                    val_acc, val_loss = test(size=32)
                    f.write(f"{MODEL_NAME}, {round(time.time(),3)},{round(float(acc),2)},{round(float(loss),4)},{round(float(val_acc),2)},{round(float(val_loss),4)}")

train()

  0%|          | 0/94 [00:00<?, ?it/s]

torch.Size([128, 2, 2])


  1%|          | 1/94 [00:00<00:11,  8.01it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  3%|▎         | 3/94 [00:00<00:09,  9.69it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  5%|▌         | 5/94 [00:00<00:07, 11.41it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


  7%|▋         | 7/94 [00:00<00:06, 13.09it/s]

torch.Size([128, 2, 2])


 10%|▉         | 9/94 [00:00<00:05, 14.49it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 13%|█▎        | 12/94 [00:00<00:05, 16.07it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 16%|█▌        | 15/94 [00:00<00:05, 15.65it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 19%|█▉        | 18/94 [00:01<00:04, 17.00it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 22%|██▏       | 21/94 [00:01<00:04, 16.49it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 26%|██▌       | 24/94 [00:01<00:03, 17.70it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 29%|██▊       | 27/94 [00:01<00:03, 17.08it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 32%|███▏      | 30/94 [00:01<00:03, 18.33it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 35%|███▌      | 33/94 [00:01<00:03, 17.75it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 38%|███▊      | 36/94 [00:01<00:03, 19.30it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 41%|████▏     | 39/94 [00:02<00:02, 19.65it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 45%|████▍     | 42/94 [00:02<00:02, 19.82it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 48%|████▊     | 45/94 [00:02<00:02, 19.64it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 50%|█████     | 47/94 [00:02<00:02, 16.69it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 53%|█████▎    | 50/94 [00:02<00:02, 18.03it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 56%|█████▋    | 53/94 [00:02<00:02, 17.16it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 60%|█████▉    | 56/94 [00:03<00:02, 18.80it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 63%|██████▎   | 59/94 [00:03<00:01, 18.28it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 66%|██████▌   | 62/94 [00:03<00:01, 20.03it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 69%|██████▉   | 65/94 [00:03<00:01, 19.48it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 72%|███████▏  | 68/94 [00:03<00:01, 20.50it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 76%|███████▌  | 71/94 [00:03<00:01, 21.26it/s]

torch.Size([128, 2, 2])


 79%|███████▊  | 74/94 [00:03<00:00, 21.75it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 82%|████████▏ | 77/94 [00:04<00:00, 22.16it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 88%|████████▊ | 83/94 [00:04<00:00, 21.61it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 91%|█████████▏| 86/94 [00:04<00:00, 20.61it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 95%|█████████▍| 89/94 [00:04<00:00, 22.26it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


 98%|█████████▊| 92/94 [00:04<00:00, 20.86it/s]

torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])


100%|██████████| 94/94 [00:04<00:00, 19.45it/s]
