In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class AlexNet(nn.Module):
    """AlexNet image classifier (Krizhevsky et al., 2012), single-tower variant.

    Five convolutional layers feed three fully connected layers.
    The first two conv layers are each followed by ReLU, local response
    normalisation and max pooling. Conv layers 3-5 use ReLU, and a max pool
    follows conv5. The fully connected layers use ReLU and dropout (p=0.5)
    between them.

    Args:
        num_classes: Number of logits produced by the final linear layer.

    Input is a batch of 3-channel images shaped ``(batch, 3, H, W)``.
    With 227x227 inputs the convolutional stack yields a 256x6x6 feature map.
    An adaptive average pool pins that map to 6x6, so other resolutions
    (e.g. 224x224) also match the classifier's input size.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        self.conv = nn.Sequential(
            # conv1: 227x227 -> 55x55, pool -> 27x27
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            # The paper's alpha is 10^-4 (the old literal `10e-4` was 10^-3).
            # NOTE(review): PyTorch's LocalResponseNorm divides alpha by `size`,
            # while the paper's formula does not; confirm which is intended.
            nn.LocalResponseNorm(size=5, k=2, alpha=1e-4, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # conv2: 27x27 -> 27x27 (padding=2), pool -> 13x13
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, k=2, alpha=1e-4, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # conv3: 13x13 -> 13x13
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # conv4: 13x13 -> 13x13
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # conv5: 13x13 -> 13x13, pool -> 6x6
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Forces a 6x6 spatial size so the first Linear layer's in_features
        # always matches. It is an identity for 227x227 inputs and has no
        # parameters, so existing state_dicts still load.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.fully = nn.Sequential(
            # fully1
            nn.Linear(in_features=6 * 6 * 256, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            # fully2
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),

            # fully3
            nn.Linear(in_features=4096, out_features=num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.conv(x)
        x = self.avgpool(x)
        # Flatten from dim 1 to keep the batch dimension. Flattening everything
        # would merge all samples into one vector and break the first Linear
        # layer for batch sizes > 1.
        x = torch.flatten(x, 1)
        return self.fully(x)
        