In [579]:
import torch
from torch import nn,optim
from torchvision import datasets,transforms
from torch.utils.data import dataloader, random_split
from tqdm import tqdm


In [580]:
# torch.manual_seed(137)

In [581]:
tf = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)
# tf = transforms.ToTensor()


In [582]:
training_set = datasets.ImageFolder(
    root="data/images/512-2048-40/combined/train", transform=tf
)
testing_set = datasets.ImageFolder(
    root="data/images/512-2048-40/combined/test", transform=tf
)

In [583]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 16


In [584]:
class PyTeen(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 8, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Flatten(),
            nn.Linear(1152, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.parameters())

    # self.to(torch.device(DEVICE)) #gpu

    def forward(self, input):
        return self.layers(input)

    def predict(self, input):
        with torch.no_grad():
            pred = self.forward(input)
            return torch.argmax(pred, axis=-1)

    def train(self, input, label):
        self.optimizer.zero_grad()
        pred = self.forward(input)
        loss = self.loss(pred, label)
        loss.backward()
        self.optimizer.step()
        return loss

In [585]:
training_loader = dataloader.DataLoader(
    training_set, batch_size=BATCH_SIZE, shuffle=True
)
testing_loader = dataloader.DataLoader(
    testing_set, batch_size=BATCH_SIZE, shuffle=False
)


In [586]:
network = PyTeen()
network.to(torch.device(DEVICE))


PyTeen(
  (layers): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1152, out_features=120, bias=True)
    (8): ReLU()
    (9): Linear(in_features=120, out_features=84, bias=True)
    (10): ReLU()
    (11): Linear(in_features=84, out_features=10, bias=True)
  )
  (loss): CrossEntropyLoss()
)

In [587]:
# training loop
# we need to make the tensors all .to(torch.device('cuda'))
EPOCHS = 5

for i in range(EPOCHS):
  total_loss = 0
  for input,label in tqdm(training_loader):
      input = input.to(torch.device(DEVICE))
      label = label.to(torch.device(DEVICE))

      loss = network.train(input,label)
      total_loss += loss
  print("EPOCH:",i+1,": ",total_loss)


100%|██████████| 3781/3781 [00:41<00:00, 90.62it/s]


EPOCH: 1 :  tensor(1718.1300, grad_fn=<AddBackward0>)


100%|██████████| 3781/3781 [00:40<00:00, 93.34it/s]


EPOCH: 2 :  tensor(468.3240, grad_fn=<AddBackward0>)


100%|██████████| 3781/3781 [00:39<00:00, 95.92it/s]


EPOCH: 3 :  tensor(319.9427, grad_fn=<AddBackward0>)


100%|██████████| 3781/3781 [00:39<00:00, 95.70it/s]


EPOCH: 4 :  tensor(247.5844, grad_fn=<AddBackward0>)


100%|██████████| 3781/3781 [00:39<00:00, 96.80it/s]

EPOCH: 5 :  tensor(204.1053, grad_fn=<AddBackward0>)





In [588]:
# evaluation loop
# we need to make the tensors all .to(torch.device('cuda'))
num_corrects = 0
for input,label in tqdm(testing_loader):
  input = input.to(torch.device(DEVICE))
  label = label.to(torch.device(DEVICE))
  pred = network.predict(input)
  for i in range(len(pred)):
    if(pred[i] == label[i]):
      num_corrects += 1
print(f"{num_corrects*100/(len(testing_loader)*BATCH_SIZE)}%")


100%|██████████| 586/586 [00:03<00:00, 179.41it/s]

96.18174061433447%





In [589]:
torch.save(network.state_dict(), "./audio_mnist.pth")
