In [None]:
import os
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader


<h2> Image Transformations </h2>

In [None]:
# ImageNet per-channel (RGB) mean / std, consumed by transforms.Normalize below.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

# Training pipeline: geometric/color augmentation is applied to the PIL image
# first, and only then is the image converted to a tensor and normalized.
# Keep any extra augmentation (e.g. the ColorJitter / RandomRotation options
# below) ahead of ToTensor/Normalize: ColorJitter clamps float tensors to
# [0, 1], which would corrupt already-normalized values.
train_transforms = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    # transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean,std),
])

# Evaluation pipeline: deterministic resize + center crop, no augmentation,
# same normalization as training.
val_transforms = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean,std),
])

<h2> Dataset Creation </h2>

In [None]:
TRAIN_DIR = "../../datasets/tiny-imagenet-200/train"
VAL_DIR = "../../datasets/tiny-imagenet-200/val"

# ImageFolder assigns label indices from sorted sub-directory names, so the
# train and val roots must contain the same class folders; otherwise val
# labels silently mean something different from train labels.
train_dataset = ImageFolder(root=TRAIN_DIR, transform=train_transforms)
val_dataset = ImageFolder(root=VAL_DIR, transform=val_transforms)

# NOTE(review): the stock Tiny ImageNet download appears to ship val/ as a flat
# images/ folder plus val_annotations.txt, which ImageFolder would read as one
# class with every label 0. It must be regrouped into per-class folders first;
# this check fails loudly instead of training against wrong val labels.
assert val_dataset.class_to_idx == train_dataset.class_to_idx, (
    f"train/val class folders differ: {len(train_dataset.classes)} train classes "
    f"vs {len(val_dataset.classes)} val classes (cwd={os.getcwd()})"
)


In [None]:
# Both splits share batching / worker / pinning settings; only shuffling
# differs (train order is reshuffled each epoch, val order stays fixed).
_loader_kwargs = dict(batch_size=64, num_workers=2, pin_memory=True)

train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **_loader_kwargs)

<h2> Dataset Testing </h2>

In [None]:
# Pull exactly one batch to confirm the loader yields tensors of the expected shape.
images, labels = next(iter(train_loader))
print(f"Image shape: {images.shape}")
print(f"Label: {labels.shape}")

<h2> AlexNet Model </h2>

In [None]:
class Alexnet(nn.Module):
    """AlexNet-style convolutional classifier.

    A five-convolution feature extractor followed by a three-layer fully
    connected classifier with dropout. For a 224x224 RGB input the feature
    extractor yields a 256 x 6 x 6 map, i.e. the 9216 features fc1 expects.

    Args:
        num_classes: size of the final logit vector. Defaults to 200, the
            output size this model originally hard-coded.
    """

    def __init__(self, num_classes=200):
        super().__init__()

        # These modules hold no parameters, so a single instance is reused at
        # every position it appears in the Sequentials below.
        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=0)
        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2)
        self.conv2 = nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5,stride=1,padding=2)
        self.conv3 = nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3,stride=1,padding=1)
        self.conv4 = nn.Conv2d(in_channels=384,out_channels=384,kernel_size=3,stride=1,padding=1)
        self.conv5 = nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1)

        feature_extractor_layers = [self.conv1, self.relu, self.maxpool, self.conv2, self.relu,
                                    self.maxpool, self.conv3, self.relu, self.conv4, self.relu,
                                    self.conv5, self.relu, self.maxpool]

        self.feature_extractor = nn.Sequential(*feature_extractor_layers)

        # Pools the final feature map to 6x6 so inputs other than 224x224
        # (large enough to survive the three max-pools) still give 9216
        # features. For 224x224 inputs the map is already 6x6 and this is a
        # no-op.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.dropout = nn.Dropout(p=0.5)

        self.fc1 = nn.Linear(in_features=256 * 6 * 6, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=4096)
        self.fc3 = nn.Linear(in_features=4096, out_features=num_classes)

        classifier_layers = [self.dropout, self.fc1, self.relu, self.dropout, self.fc2,
                             self.relu, self.fc3]
        self.classifier = nn.Sequential(*classifier_layers)


    def forward(self,images):
        """Compute class logits.

        Args:
            images: float tensor of shape (N, 3, H, W); 224x224 is the size
                the layer dimensions were designed around.

        Returns:
            Unnormalized logits of shape (N, num_classes).
        """
        out = self.feature_extractor(images)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        # The original omitted this return, so the model produced None.
        return self.classifier(out)

