In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import matplotlib.pyplot as plt
from ema_pytorch import EMA
from torchinfo import summary

In [2]:
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CelebA(
        root="./data",
        split="train",
        target_type="attr",
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(64),
                torchvision.transforms.ToTensor(),
            ]
        ),
    ),
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
val_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CelebA(
        root="./data",
        split="valid",
        target_type="attr",
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(64),
                torchvision.transforms.ToTensor(),
            ]
        ),
    ),
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# num of attributes
print(len(train_loader.dataset.attr_names))

41


In [20]:
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding, groups=in_channels
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.norm1 = nn.GroupNorm(16, out_channels)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.depthwise(x)
        # x = self.act(x)

        x = self.pointwise(x)
        x = self.act(x)
        x = self.norm1(x)

        return x


class ConvStack(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers,
        kernel_size=3,
        stride=1,
        padding=1,
        conv_type=DepthwiseSeparableConv2d,
    ):
        super(ConvStack, self).__init__()
        self.layers = nn.ModuleList(
            [
                conv_type(in_channels, out_channels, kernel_size, stride, padding)
                if _ == 0
                else conv_type(out_channels, out_channels, kernel_size, stride, padding)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        x = self.layers[0](x)
        for layer in self.layers[1:]:
            identity = x
            x = layer(x) + identity

        return x


class Classifer(nn.Module):
    def __init__(self):
        super(Classifer, self).__init__()
        # self.backbone = nn.Sequential(
        #     ConvStack(3, 32, 2),
        #     nn.MaxPool2d(2),
        #     ConvStack(32, 64, 2),
        #     nn.MaxPool2d(2),
        #     ConvStack(64, 128, 2),
        #     nn.MaxPool2d(2),
        #     ConvStack(128, 256, 4),
        #     nn.MaxPool2d(2),
        #     ConvStack(256, 256, 4),
        # )
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3, 2, 1),
        )
        
        self.classifier = nn.Linear(256, 40)

    def forward(self, x):
        x = self.backbone(x).view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")


model = Classifer().to("mps")
summary(
    model,
    input_data=torch.randn(64, 3, 64, 64, device="mps", requires_grad=False),
)
test()

Testing: 100%|██████████| 311/311 [00:08<00:00, 38.74it/s]


32.702693939208984

In [21]:
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
criterion = nn.CrossEntropyLoss()

In [22]:
@torch.no_grad()
def test() -> float:
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for img, label in tqdm(val_loader, desc="Testing", leave=True):
            img, label = img.to("mps"), label.to("mps")
            output = model(img)
            test_loss += criterion(output, label.float()).mean()

    test_loss /= len(val_loader)
    test_loss = test_loss.item()
    
    return test_loss

In [23]:
test_loss = 0
test_acc = 0
for epoch in range(10):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for img, label in pbar:
        img, label = img.to("mps"), label.to("mps")
        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, label.float())
            
        loss.backward()
        optimizer.step()

        ema.update()
        pbar.set_postfix_str(f"loss: {loss.item():.4f}, test_loss: {test_loss:.4f}, test_acc: {test_acc:.4f}")

Epoch 1: 100%|██████████| 2544/2544 [00:55<00:00, 45.57it/s, loss: 27.9069, test_loss: 0.0000, test_acc: 0.0000]
Epoch 2: 100%|██████████| 2544/2544 [01:30<00:00, 28.00it/s, loss: 24.2872, test_loss: 0.0000, test_acc: 0.0000]
Epoch 3:  25%|██▌       | 645/2544 [00:33<01:39, 19.15it/s, loss: 25.3475, test_loss: 0.0000, test_acc: 0.0000]


KeyboardInterrupt: 