In [None]:
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time

In [None]:
# --- Data pipeline configuration -------------------------------------------
# Images are resized to 227x227, the input size for which the AlexNet defined
# in this notebook produces the 6x6x256 feature map its classifier expects.
IMAGE_SIZE = 227
BATCH_SIZE = 64
# NOTE(review): machine-specific absolute path; adjust for your environment.
DATA_ROOT = '/home/kami/Documents/datasets/'

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize images to 227x227 pixels
    torchvision.transforms.ToTensor(),  # Convert PIL images to float tensors in [0, 1]
    # A single-element mean/std is broadcast over every channel, mapping
    # [0, 1] -> [-1, 1].
    torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# download=False: the dataset must already be present under DATA_ROOT.
train_dataset = torchvision.datasets.Imagenette(
    root=DATA_ROOT,
    size="160px",
    download=False,
    transform=transform,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


# Sanity check: inspect only the first batch to verify the pipeline.
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")  # Expected [64, 3, 227, 227] (batch, channels, height, width)
    print(f"Labels shape: {labels.shape}")  # Expected [64]
    print(f"Image tensor min: {images.min()}, max: {images.max()}")  # Expected within [-1, 1]
    # Enforce the layout the model below is built for instead of relying on
    # a comment: 3 channels, IMAGE_SIZE x IMAGE_SIZE pixels.
    assert images.shape[1:] == (3, IMAGE_SIZE, IMAGE_SIZE), (
        f"unexpected image layout: {tuple(images.shape)}"
    )
    break  # Only check the first batch




In [None]:
class AlexNet(nn.Module):
    """AlexNet-style CNN with batch normalisation after every convolution.

    Architecture: five convolutional stages (``layer1`` .. ``layer5``) followed
    by three fully connected stages (``fc``, ``fc1``, ``fc2``).

    The feature map is adaptively pooled to 6x6 before flattening, so the
    classifier always receives 256 * 6 * 6 = 9216 features. For 227x227
    inputs the convolutional stack already produces a 6x6 map, in which case
    the adaptive pooling is an exact identity; for other input sizes (e.g.
    224x224) it prevents a shape mismatch in the first linear layer.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels of the input images.
    """

    # Spatial size the feature map is pooled to before the classifier.
    _POOLED_SIZE = 6

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # NOTE: module creation order is kept (layer1 .. fc2) so that seeded
        # parameter initialisation is reproducible across versions.
        self.layer1 = self._conv_block(in_channels, 96, kernel_size=11, stride=4, padding=0, pool=True)
        self.layer2 = self._conv_block(96, 256, kernel_size=5, stride=1, padding=2, pool=True)
        self.layer3 = self._conv_block(256, 384, kernel_size=3, stride=1, padding=1)
        self.layer4 = self._conv_block(384, 384, kernel_size=3, stride=1, padding=1)
        self.layer5 = self._conv_block(384, 256, kernel_size=3, stride=1, padding=1, pool=True)

        # Parameter-free, so it adds no entries to the state_dict.
        self.avgpool = nn.AdaptiveAvgPool2d((self._POOLED_SIZE, self._POOLED_SIZE))

        flat_features = 256 * self._POOLED_SIZE * self._POOLED_SIZE  # 9216
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flat_features, 4096),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(4096, num_classes))

    @staticmethod
    def _conv_block(in_ch, out_ch, kernel_size, stride=1, padding=0, pool=False):
        """Build Conv2d -> BatchNorm2d -> ReLU, optionally followed by a 3x3/stride-2 max-pool."""
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``.

        ``x`` is expected to be a 4-D tensor (batch, in_channels, height, width).
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # Normalise the spatial size so the flattened vector always has 9216 entries.
        out = self.avgpool(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

In [None]:
# --- Training configuration -------------------------------------------------
NUM_EPOCHS = 10
LOG_EVERY = 100  # print the current batch loss every LOG_EVERY batches

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = AlexNet(num_classes=10,in_channels=3).to(device)
criterion = nn.CrossEntropyLoss()  # Combines log softmax and NLL loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam with learning rate 1e-3

# perf_counter is a monotonic clock intended for measuring elapsed intervals.
st = time.perf_counter()
model.train()  # enable dropout and batch-norm statistics updates
for epoch in range(NUM_EPOCHS):
    # enumerate(..., start=1) gives a 1-based batch counter without manual bookkeeping.
    for batch_index, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        if batch_index % LOG_EVERY == 0:
            # The reported loss is that of the current batch only, not a running average.
            print(f"Epoch: {epoch} | Batch: {batch_index} | Loss: {loss.item():.4f}")

# CUDA kernels are launched asynchronously; wait for outstanding work so the
# measured time covers the full training computation.
if device.type == 'cuda':
    torch.cuda.synchronize(device)
et = time.perf_counter()
print(et-st)  # elapsed training time in seconds
print("Training completed.")