# AlexNet 만들기

AlexNet 구조
![image.png](attachment:7a0d3fb0-cc45-4354-963c-3a66152fce8d.png)![image.png](attachment:322c25ed-cd4a-4a95-a10a-3ecd6b170bcd.png)

![image.png](attachment:c0af5a7e-37c9-4f64-b414-e60218230396.png)

In [25]:
import torch
from torch import nn
import torch.nn.functional as F


class Alexnet(nn.Module):
    """AlexNet-style convolutional image classifier.

    Five convolution layers (C1-C5) with ReLU activations and three shared
    3x3 / stride-2 max-pools, followed by three fully connected layers
    (F6-F8) and a softmax over the classes.

    Args:
        num_classes: number of output classes (size of the F8 layer).

    NOTE(review): conv2-conv5 use no padding, so a 227x227 image shrinks to
    a 2x2 map after the last max-pool (hence ``256 * 2 * 2`` inputs to fc1).
    The original AlexNet pads these layers and ends at 6x6, and also uses
    dropout and local response normalization, which are absent here.
    Confirm this deviation is intended.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # C1: in-channels = 3, out-channels = 96, filter = 11x11, stride 4
        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        # C2: 96 -> 256 channels, filter = 5x5
        self.conv2 = nn.Conv2d(96, 256, 5, stride=1)
        # C3: 256 -> 384 channels, filter = 3x3
        self.conv3 = nn.Conv2d(256, 384, 3, stride=1)
        # C4: 384 -> 384 channels, filter = 3x3
        self.conv4 = nn.Conv2d(384, 384, 3, stride=1)
        # C5: 384 -> 256 channels, filter = 3x3
        self.conv5 = nn.Conv2d(384, 256, 3, stride=1)

        # Max-pooling with a 3x3 window and stride 2, reused after C1, C2, C5.
        self.pool = nn.MaxPool2d(3, 2)
        # Resample the C5 feature map to a fixed 2x2 grid so fc1's input size
        # does not depend on the image size. For a 227x227 image the map is
        # already 2x2, so this is an exact identity; it has no parameters,
        # so the state_dict is unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))

        # F6: 256 channels * 2 * 2 spatial positions -> 4096
        self.fc1 = nn.Linear(256 * 2 * 2, 4096)
        # F7
        self.fc2 = nn.Linear(4096, 4096)
        # F8: class scores
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return class probabilities for a batch of images.

        Args:
            x: image batch; assumed to be (N, 3, H, W) since conv1 expects 3
                input channels. H and W must be large enough to survive the
                convolutions and max-pools (227x227 is the reference size).

        Returns:
            Tensor of shape (N, num_classes) whose rows are softmax
            probabilities (each row sums to 1).

        NOTE(review): if this model is trained with nn.CrossEntropyLoss,
        which applies log-softmax internally, that loss should receive the
        fc3 logits rather than these probabilities.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(F.relu(self.conv5(x)))

        # (N, 256, h, w) -> (N, 256, 2, 2) -> (N, 1024)
        x = torch.flatten(self.avgpool(x), 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=1)

In [28]:
# Smoke test: run one random 227x227 RGB image through the network and
# check that the output holds one score per class (1000 by default).
sample_in = torch.randn(size=(1, 3, 227, 227))
print(sample_in.size())

alexnet = Alexnet()
outputs = alexnet(sample_in)
print(outputs.size())

torch.Size([1, 3, 227, 227])
torch.Size([1, 1000])
