In [22]:
import torch.nn as nn
import torchvision
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from random import choices, sample
import numpy as np
import random
from sklearn.model_selection import train_test_split
import warnings

In [23]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Single seed shared by every random number generator seeded in the next cell.
seed = 42

# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including PyTorch deprecation warnings; consider filtering specific categories instead.
warnings.filterwarnings("ignore")

In [24]:
# Seed every RNG the notebook touches. `torch.manual_seed` is the same function as
# `torch.random.manual_seed`, so one call covers the torch CPU generator (it also seeds
# CUDA devices; the explicit CUDA call is kept for readability).
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Seeds alone do not make GPU runs repeatable: cuDNN may select nondeterministic
# convolution algorithms. Pin it to deterministic ones (possibly at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [25]:
# Synthetic stand-in for real data: integer pixel values drawn uniformly from [0, 255],
# stored as float32 and shaped (N, C, H, W). 227x227 matches the AlexNet input size.
N_MOCK_IMAGES = 500
MOCK_IMAGE_SHAPE = (3, 227, 227)
# Sample directly as float32 (skipping the ~620 MB int64 intermediate that
# `randint(...).float()` allocates for this batch). Sampling still happens on the CPU
# generator, so the data does not depend on which device is available.
mock_images = torch.randint(
  0, 256, (N_MOCK_IMAGES, *MOCK_IMAGE_SHAPE), dtype=torch.float32
).to(device)

In [26]:
class AlexNet(nn.Module):
  """AlexNet (Krizhevsky, Sutskever & Hinton, 2012) as a single-tower PyTorch module.

  Five convolutional layers, with local response normalisation after conv1 and conv2
  and 3x3/stride-2 max pooling after conv1, conv2 and conv5. These are followed by
  three fully connected layers, with dropout (p=0.5) on the two hidden ones.

  Initialisation follows the paper: every weight ~ N(0, 0.01^2); the biases of conv2,
  conv4, conv5 and the two hidden fully connected layers are 1, and all other biases
  (conv1, conv3, and the output layer fc3) are 0.

  Args:
    num_classes: width of the output layer ``fc3``. Defaults to 1000.
  """

  def __init__(self, num_classes: int = 1000):
    super().__init__()
    # Each layer is re-initialised immediately after it is constructed, so the global
    # RNG is consumed layer by layer (same order as constructing + initialising inline).
    self.conv1 = self._init_layer(nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4), bias=0.0)

    # Both modules are parameter-free, so one instance is reused at every site.
    self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
    # PyTorch's LocalResponseNorm divides alpha by `size`, so alpha = 5 * 1e-4 here
    # reproduces the paper's alpha = 1e-4 on the raw sum of squares over n = 5 channels.
    self.localresponsenorm = nn.LocalResponseNorm(size=5, alpha = 5*10**(-4), k=2, beta=0.75)

    self.conv2 = self._init_layer(nn.Conv2d(96, 256, kernel_size=(5, 5), padding=2), bias=1.0)
    self.conv3 = self._init_layer(nn.Conv2d(256, 384, kernel_size=(3, 3), padding=1), bias=0.0)
    self.conv4 = self._init_layer(nn.Conv2d(384, 384, kernel_size=(3, 3), padding=1), bias=1.0)
    self.conv5 = self._init_layer(nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1), bias=1.0)

    # 9216 = 256 channels * 6 * 6: the pooled conv5 map for a 227x227 input.
    self.fc1 = self._init_layer(nn.Linear(9216, 4096), bias=1.0)
    self.fc2 = self._init_layer(nn.Linear(4096, 4096), bias=1.0)
    # Output layer is not a hidden layer, so (per the paper) its bias starts at 0.
    self.fc3 = self._init_layer(nn.Linear(4096, num_classes), bias=0.0)

  @staticmethod
  def _init_layer(layer, bias):
    """Set `layer.weight` ~ N(0, 0.01^2) and `layer.bias` to the constant `bias`, in place; return `layer`."""
    nn.init.normal_(layer.weight, mean=0, std=0.01)
    nn.init.constant_(layer.bias, bias)
    return layer

  def forward(self, x):
    """Map a batch of images to class probabilities.

    Args:
      x: float tensor of shape (N, 3, H, W). A 227x227 input yields the
        256 x 6 x 6 = 9216 flattened features that ``fc1`` expects.

    Returns:
      Tensor of shape (N, num_classes) holding softmax probabilities along dim 1.
      These are probabilities, not logits; nn.CrossEntropyLoss expects the
      pre-softmax ``fc3`` output instead.
    """
    x = nn.functional.relu(self.conv1(x))
    x = self.localresponsenorm(x)
    x = self.maxpool(x)
    x = nn.functional.relu(self.conv2(x))
    x = self.localresponsenorm(x)
    x = self.maxpool(x)
    x = nn.functional.relu(self.conv3(x))
    x = nn.functional.relu(self.conv4(x))
    x = nn.functional.relu(self.conv5(x))
    x = self.maxpool(x)
    x = torch.flatten(x, 1)
    # `training=self.training` makes dropout follow model.train()/model.eval();
    # F.dropout defaults to training=True and would keep dropping units at inference.
    x = nn.functional.dropout(nn.functional.relu(self.fc1(x)), p=0.5, training=self.training)
    x = nn.functional.dropout(nn.functional.relu(self.fc2(x)), p=0.5, training=self.training)

    # Explicit dim=1 (per-sample softmax); the implicit-dim form is deprecated.
    return nn.functional.softmax(self.fc3(x), dim=1)


In [27]:
model = AlexNet().to(device)
# Smoke-test forward pass only: no_grad skips recording the autograd graph, which for
# 500 images at 227x227 would keep gigabytes of intermediate activations alive.
with torch.no_grad():
  out = model(mock_images)
# One row of class scores per mock image.
assert out.shape == (mock_images.shape[0], model.fc3.out_features)
