# Baseline Transformer models for the Galaxy challenge

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

plt.rcParams["axes.grid"] = False

In [None]:
# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"GPU available: {torch.cuda.get_device_name(0)}" if use_cuda else "Using CPU")

# Echo the final choice so it is recorded in the cell output.
print(device)

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Watermarked data

In [None]:
# Noise level of the dataset to use; it is interpolated into the dataset file
# names loaded below and into the checkpoint names written during training.
sigma = 0.03

In [None]:
# Watermark offset of the dataset to use; like `sigma`, it is interpolated into
# the dataset file names and the checkpoint names.
offset = 0.01

In [None]:
# alpha = 0.02

In [None]:
# Load the pre-generated "barsplit_post" train/test files for the chosen
# sigma/offset. Each file unpacks into two objects: the images and the
# per-image targets (`*_ns`) that are later used as classification labels.
train_images, train_ns = torch.load(
    f"../data/315/watermarked/dataset_train_barsplit_post_{sigma}_{offset}.pt"
)
test_images, test_ns = torch.load(
    f"../data/315/watermarked/dataset_test_barsplit_post_{sigma}_{offset}.pt"
)

In [None]:
# train_images, train_ns = torch.load(
#     f"../data/315/watermarked/dataset_train_barsplit_{sigma}_{offset}.pt"
# )
# test_images, test_ns = torch.load(
#     f"../data/315/watermarked/dataset_test_barsplit_{sigma}_{offset}.pt"
# )

In [None]:
# train_images, train_ns = torch.load(
#     f"../data/315/watermarked/dataset_train_barrandom_{sigma}_{alpha}.pt"
# )
# test_images, test_ns = torch.load(
#     f"../data/315/watermarked/dataset_test_barrandom_{sigma}_{alpha}.pt"
# )

In [None]:
(
    train_images.shape,
    train_ns.shape,
    test_images.shape,
    test_ns.shape,
)  # , test_images_kaggle.shape, test_ns_kaggle.shape

In [None]:
# Insert a singleton axis at dim 1; the models below start with convolutions
# that take one input channel (presumably (N, H, W) -> (N, 1, H, W) — confirm
# against the shapes printed above).
train_images = train_images.unsqueeze(1)
test_images = test_images.unsqueeze(1)

In [None]:
train_images.shape, train_ns.shape, train_images.dtype, test_images.dtype

In [None]:
# Pair every image with its target so the loaders yield (images, labels) batches.
train_dataset = TensorDataset(train_images, train_ns)
test_dataset = TensorDataset(test_images, test_ns)

In [None]:
# Mini-batches of 32 loaded by two worker processes; only the training set is
# shuffled, the test set keeps its stored order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

In [None]:
# kaggle_dataset = TensorDataset(test_images_kaggle, test_ns_kaggle)

In [None]:
# NOTE: don't shuffle for the kaggle dataset
# kaggle_loader = DataLoader(kaggle_dataset, batch_size=32, shuffle=False, num_workers=2)

In [None]:
# Show one random training image with its target as the title.
# torch.randint needs an explicit output shape, so draw a single index over the
# actual dataset length and convert it to a plain int before indexing.
i = torch.randint(0, len(train_images), (1,)).item()
plt.imshow(train_images[i].squeeze(), cmap="gray")
plt.title(f"{train_ns[i]}");

In [None]:
# Show one random test image with its target as the title.
# torch.randint needs an explicit output shape, so draw a single index over the
# actual dataset length and convert it to a plain int before indexing.
i = torch.randint(0, len(test_images), (1,)).item()
plt.imshow(test_images[i].squeeze(), cmap="gray")
plt.title(f"{test_ns[i]}");

# Utils

In [None]:
dim = 50

In [None]:
n_classes = 7

In [None]:
num_epochs = 40
lr = 1e-3

In [None]:
model_path = "../res/models/galaxy_challenge2/"

In [None]:
def train_model(model, train_loader, test_loader, num_epochs, lr, model_save_path):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: Network to optimise in place; batches are moved to the device
            its parameters live on.
        train_loader: Iterable of ``(images, labels)`` batches used for updates.
        test_loader: Iterable of ``(images, labels)`` batches used for
            evaluation only (no gradient tracking).
        num_epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to Adam.
        model_save_path: Checkpoint prefix. After every 10th epoch the state
            dict is written to ``f"{model_save_path}_epoch_{epoch}.pth"``.

    Returns:
        ``(train_losses, test_losses)``: one value per epoch, each the mean of
        the per-batch losses over the corresponding loader.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    # Create the checkpoint directory up front so a missing folder cannot make
    # the first save fail after ten epochs of training.
    checkpoint_dir = os.path.dirname(model_save_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Follow the model's own device instead of relying on a global variable.
    model_device = next(model.parameters()).device

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    test_losses = []

    # The context manager closes the progress bar even if training is interrupted.
    with tqdm(
        total=num_epochs,
        desc=f"Epoch: 0/{num_epochs} | Train Loss: N/A | Test Loss: N/A",
    ) as pbar:
        for epoch in range(num_epochs):
            # --- optimisation pass ---
            model.train()
            running_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(model_device), labels.to(model_device)

                optimizer.zero_grad()
                loss = criterion(model(images), labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            train_loss = running_loss / len(train_loader)
            train_losses.append(train_loss)

            # --- evaluation pass ---
            model.eval()
            test_running_loss = 0.0
            with torch.no_grad():
                for images, labels in test_loader:
                    images, labels = images.to(model_device), labels.to(model_device)
                    test_running_loss += criterion(model(images), labels).item()

            test_loss = test_running_loss / len(test_loader)
            test_losses.append(test_loss)

            pbar.set_description(
                f"Epoch: {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} "
                f"| Test Loss: {test_loss:.4f}"
            )
            pbar.update(1)

            # Periodic checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                checkpoint_path = f"{model_save_path}_epoch_{epoch + 1}.pth"
                torch.save(model.state_dict(), checkpoint_path)

    return train_losses, test_losses

In [None]:
def _accuracy_percent(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` in percent (0 if the loader is empty)."""
    # Look the device up once instead of once per batch.
    model_device = next(model.parameters()).device
    hits = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(model_device), labels.to(model_device)
            predicted = model(images).argmax(dim=1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
    return 100 * hits / seen if seen else 0


def plot_losses_and_evaluate(model, train_loader, test_loader, train_losses, test_losses):
    """Plot the per-epoch loss curves and print train/test accuracy.

    Args:
        model: Trained network; it is switched to evaluation mode and left there.
        train_loader: Iterable of ``(images, labels)`` batches of the training split.
        test_loader: Iterable of ``(images, labels)`` batches of the test split.
        train_losses: Sequence of per-epoch training losses.
        test_losses: Sequence of per-epoch test losses.

    Returns:
        None. The figure is shown and both accuracies are printed.
    """
    # Plot losses
    plt.figure()
    plt.plot(train_losses, label="train")
    plt.plot(test_losses, label="test")
    all_losses = list(train_losses) + list(test_losses)
    if all_losses:  # max() of an empty sequence would raise
        plt.ylim(0, max(all_losses))
    plt.legend()
    plt.title("Train vs Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    # Evaluate accuracies
    model.eval()
    test_accuracy = _accuracy_percent(model, test_loader)
    train_accuracy = _accuracy_percent(model, train_loader)

    print(f"Train Accuracy: {train_accuracy:.2f}%")
    print(f"Test Accuracy: {test_accuracy:.2f}%")

# ChatGPT CNN from previous challenge

In [None]:
class StarCounterCNN(nn.Module):
    """Three conv + max-pool stages followed by a two-layer classifier.

    The first linear layer is sized for 64 x 6 x 6 feature maps, i.e. for
    single-channel 50 x 50 inputs (50 -> 25 -> 12 -> 6 after the three pools).

    Args:
        num_classes: Number of output logits. Defaults to 7 (0-6 stars).
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # Input shape: (batch_size, 1, 50, 50)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=16, kernel_size=3, padding=1
        )  # (16, 50, 50)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # (16, 25, 25)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # (32, 25, 25)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # (32, 12, 12)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # (64, 12, 12)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # (64, 6, 6)

        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch_size, num_classes)``."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        x = torch.flatten(x, 1)  # keep the batch axis, flatten the rest
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
model_cgpt = StarCounterCNN()
model_cgpt = model_cgpt.to(device)

In [None]:
train_losses, test_losses = train_model(
    model_cgpt,
    train_loader,
    test_loader,
    num_epochs,
    lr,
    f"{model_path}cgpt_barsplit_post_{sigma}_{offset}",
)

## Split watermark, post-noise $\sigma = 0.03$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## Split watermark, post-noise $\sigma = 0.03$, `offset = 0.05`

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.04$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.02$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## Bar watermark, $\sigma = 0.05$, $\alpha = 0.02$

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## Bar watermark, $\sigma = 0.02$, $\alpha = 0.02$

In [None]:
plot_losses_and_evaluate(model_cgpt, train_loader, test_loader, train_losses, test_losses)

## $\sigma = 0.02$, `offset = 0.005`

In [None]:
# Loss curves of the most recent run; the y-axis spans 0 to the largest recorded loss.
for curve, curve_label in ((train_losses, "train"), (test_losses, "test")):
    plt.plot(curve, label=curve_label)
plt.ylim(0, max(train_losses + test_losses))
plt.legend();

In [None]:
# Top-1 accuracy of the CNN on the test split and on the training split.
model_cgpt.eval()
correct = total = 0
correct_train = total_train = 0

with torch.no_grad():  # inference only
    # Held-out split.
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_cgpt(batch_x).data, 1)[1]
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    # Training split.
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_cgpt(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()

print(f"Train Accuracy: {100 * correct_train / total_train:.2f}%")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

## $\sigma = 0.0125$, `offset = 0.01`

In [None]:
# Loss curves of the most recent run; the y-axis spans 0 to the largest recorded loss.
for curve, curve_label in ((train_losses, "train"), (test_losses, "test")):
    plt.plot(curve, label=curve_label)
plt.ylim(0, max(train_losses + test_losses))
plt.legend();

In [None]:
# Top-1 accuracy of the CNN on the test split and on the training split.
model_cgpt.eval()
correct = total = 0
correct_train = total_train = 0

with torch.no_grad():  # inference only
    # Held-out split.
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_cgpt(batch_x).data, 1)[1]
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    # Training split.
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_cgpt(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()

print(f"Train Accuracy: {100 * correct_train / total_train:.2f}%")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

# VGG from previous challenge

In [None]:
from torchvision import models

In [None]:
def make_vgg11(device, num_classes=7):
    """Build a randomly initialised VGG11-BN adapted to single-channel images.

    Args:
        device: Device the finished model is moved to.
        num_classes: Number of output logits. Defaults to 7.

    Returns:
        The modified model, already on ``device``.
    """
    # Calling the builder without arguments gives untrained weights; the
    # explicit `pretrained=False` keyword is deprecated in recent torchvision.
    model = models.vgg11_bn()

    # Need the first conv layer to accept 1 channel instead of 3.
    model.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)

    # Pool the final feature map down to 1 x 1 so the flattened feature vector
    # matches the 512 inputs of the replacement classifier below.
    model.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    # Replace the stock classifier with a small head for this task.
    model.classifier = nn.Sequential(
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(256, num_classes),
    )

    return model.to(device)

In [None]:
model_vgg = make_vgg11(device)

In [None]:
train_losses, test_losses = train_model(
    model_vgg,
    train_loader,
    test_loader,
    num_epochs,
    lr=1e-3,
    model_save_path=f"{model_path}vgg11_barsplit_post_{sigma}_{offset}",
)

## Split watermark, post-noise $\sigma = 0.03$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_vgg, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.04$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_vgg, train_loader, test_loader, train_losses, test_losses)

## Bar watermark, $\sigma = 0.05$, $\alpha = 0.02$

In [None]:
plot_losses_and_evaluate(model_vgg, train_loader, test_loader, train_losses, test_losses)

## Bar watermark, $\sigma = 0.02$, $\alpha = 0.02$

In [None]:
plot_losses_and_evaluate(model_vgg, train_loader, test_loader, train_losses, test_losses)

## $\sigma = 0.0125$, `offset = 0.01`

In [None]:
# Loss curves of the most recent run; the y-axis spans 0 to the largest recorded loss.
for curve, curve_label in ((train_losses, "train"), (test_losses, "test")):
    plt.plot(curve, label=curve_label)
plt.ylim(0, max(train_losses + test_losses))
plt.legend();

In [None]:
# Top-1 accuracy of the VGG model on the test split and on the training split.
model_vgg.eval()
correct = total = 0
correct_train = total_train = 0

with torch.no_grad():  # inference only
    # Held-out split.
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_vgg(batch_x).data, 1)[1]
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    # Training split.
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_vgg(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()


print(f"Train Accuracy: {100 * correct_train / total_train:.2f}%")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

# CNN + Attention

In [None]:
class GalaxyTransformer(nn.Module):
    """Small CNN stem followed by a Transformer encoder over spatial tokens.

    The stem maps a single-channel image to a 16 x 16 grid of 64-channel
    features, a 1 x 1 convolution projects them to ``embed_dim``, and every
    grid cell is treated as one token. Token outputs are averaged and fed to
    a LayerNorm + Linear classifier.
    """

    def __init__(self, num_classes=7, embed_dim=128, num_heads=4, num_layers=2, dropout=0.1):
        super().__init__()
        stem_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((16, 16)),  # fixed 16 x 16 token grid
        ]
        self.conv_block = nn.Sequential(*stem_layers)

        # 1 x 1 projection: (B, 64, 16, 16) -> (B, embed_dim, 16, 16)
        self.embedding_proj = nn.Conv2d(64, embed_dim, kernel_size=1)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout),
            num_layers=num_layers,
        )

        head_norm = nn.LayerNorm(embed_dim)
        head_linear = nn.Linear(embed_dim, num_classes)
        self.classifier = nn.Sequential(head_norm, head_linear)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.embedding_proj(self.conv_block(x))

        # (B, C, H, W) -> (B, C, H*W) -> (H*W, B, C): the encoder is sequence-first.
        tokens = features.flatten(2).permute(2, 0, 1)

        encoded = self.transformer_encoder(tokens)

        # Average over the token axis to get one vector per image.
        pooled = encoded.mean(dim=0)
        return self.classifier(pooled)

In [None]:
model_small = GalaxyTransformer(
    num_classes=n_classes, embed_dim=64, num_heads=2, num_layers=2, dropout=0.05
).to(device)

In [None]:
model_normal = GalaxyTransformer(
    num_classes=n_classes, embed_dim=128, num_heads=4, num_layers=3, dropout=0.025
).to(device)

In [None]:
# Train the small CNN + attention model for 100 epochs on the loaded dataset.
# The checkpoint prefix is tagged with `sigma` and `offset`, the two values that
# are actually defined above and that select the loaded "barsplit_post" files.
# (`alpha` is only assigned in a commented-out cell, so using it here raised
# NameError.)
train_losses, test_losses = train_model(
    model_small,
    train_loader,
    test_loader,
    100,
    lr=3e-4,
    model_save_path=f"{model_path}CNN_transformer_small_barsplit_post_{sigma}_{offset}",
)

## Split watermark, post-noise $\sigma = 0.03$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_small, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.04$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_normal, train_loader, test_loader, train_losses, test_losses)

In [None]:
plot_losses_and_evaluate(model_small, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.02$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(model_small, train_loader, test_loader, train_losses, test_losses)

## Bar watermark, $\sigma = 0.05$, $\alpha = 0.02$

In [None]:
# Rebuild the 3-layer, 128-dim CNN + attention architecture and restore its
# epoch-50 weights from the "barrandom" run.
# NOTE(review): `alpha` is only assigned in a commented-out cell near the top of
# the notebook; define it (and set `sigma` to match the checkpoint) before
# running this cell.
model = GalaxyTransformer(
    num_classes=n_classes, embed_dim=128, num_heads=4, num_layers=3, dropout=0.025
).to(device)
# map_location lets a checkpoint saved on a GPU load on the currently selected device.
model.load_state_dict(
    torch.load(
        f"{model_path}CNN_transformer_normal_barrandom_{sigma}_{alpha}_epoch_50.pth",
        map_location=device,
    )
)

In [None]:
# Top-1 accuracy of the restored model on both splits, in percent.
model.eval()  # evaluation mode
model_device = next(model.parameters()).device

# Held-out split.
correct_test = total_test = 0
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(model_device), batch_y.to(model_device)
        predicted = torch.max(model(batch_x).data, 1)[1]
        total_test += batch_y.size(0)
        correct_test += (predicted == batch_y).sum().item()
test_accuracy = 100 * correct_test / total_test if total_test else 0

# Training split.
correct_train = total_train = 0
with torch.no_grad():
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(model_device), batch_y.to(model_device)
        predicted = torch.max(model(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()
train_accuracy = 100 * correct_train / total_train if total_train else 0

print(f"Train Accuracy: {train_accuracy:.2f}%")
print(f"Test Accuracy: {test_accuracy:.2f}%")

## Bar watermark, $\sigma = 0.02$, $\alpha = 0.02$

In [None]:
plot_losses_and_evaluate(model_normal, train_loader, test_loader, train_losses, test_losses)

## $\sigma = 0.02$, `offset = 0.005`

In [None]:
# Loss curves of the most recent run; the y-axis spans 0 to the largest recorded loss.
for curve, curve_label in ((train_losses, "train"), (test_losses, "test")):
    plt.plot(curve, label=curve_label)
plt.ylim(0, max(train_losses + test_losses))
plt.legend();

In [None]:
model = GalaxyTransformer(
    num_classes=n_classes, embed_dim=128, num_heads=4, num_layers=3, dropout=0.025
).to(device)

In [None]:
# Restore the epoch-70 weights of the "normal2" run for the current sigma/offset.
# map_location lets a checkpoint saved on a GPU load on the currently selected device.
model.load_state_dict(
    torch.load(
        f"{model_path}CNN_transformer_normal2_{sigma}_{offset}_epoch_70.pth",
        map_location=device,
    )
)

In [None]:
# Top-1 accuracy of the restored model on the test split and on the training split.
model.eval()
correct = total = 0
correct_train = total_train = 0

with torch.no_grad():  # inference only
    # Held-out split.
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model(batch_x).data, 1)[1]
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    # Training split.
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()


print(f"Train Accuracy: {100 * correct_train / total_train:.2f}%")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

## $\sigma = 0.0125$, `offset = 0.01`

In [None]:
# Loss curves of the most recent run; the y-axis spans 0 to the largest recorded loss.
for curve, curve_label in ((train_losses, "train"), (test_losses, "test")):
    plt.plot(curve, label=curve_label)
plt.ylim(0, max(train_losses + test_losses))
plt.legend();

In [None]:
# Top-1 accuracy of the small CNN + attention model on both splits.
model_small.eval()
correct = total = 0
correct_train = total_train = 0

with torch.no_grad():  # inference only
    # Held-out split.
    for batch_x, batch_y in test_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_small(batch_x).data, 1)[1]
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    # Training split.
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        predicted = torch.max(model_small(batch_x).data, 1)[1]
        total_train += batch_y.size(0)
        correct_train += (predicted == batch_y).sum().item()


print(f"Train Accuracy: {100 * correct_train / total_train:.2f}%")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

# Pure Transformer

In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    produces one ``embed_dim``-dimensional vector per patch.
    """

    def __init__(self, img_size=50, patch_size=5, in_chans=1, embed_dim=128):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return patch tokens of shape ``(batch, num_patches, embed_dim)``."""
        patch_grid = self.proj(x)  # (B, embed_dim, rows, cols)
        tokens = patch_grid.flatten(2)  # merge the two grid axes
        return tokens.transpose(1, 2)  # token axis before the feature axis


class PureTransformerClassifier(nn.Module):
    """Patch-based Transformer classifier with a learnable class token.

    Images are split into patch tokens by ``PatchEmbed``, a class token is
    prepended, learnable positional embeddings are added, and the encoder
    output at the class-token position is normalised and mapped to logits.
    """

    def __init__(
        self,
        img_size=50,
        patch_size=5,
        in_chans=1,
        num_classes=7,
        embed_dim=128,
        depth=6,
        num_heads=4,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        token_count = self.patch_embed.num_patches + 1  # patches plus the class token

        # Learnable class token and one positional embedding per token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Stack of `depth` encoder layers with GELU feed-forward blocks.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout,
                activation="gelu",
            ),
            num_layers=depth,
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise the embeddings (truncated normal) and the head (Xavier, zero bias)."""
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.constant_(self.head.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        batch_size = x.size(0)
        tokens = self.patch_embed(x)

        # Class token goes first so it can be read back at index 0.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed)

        # The encoder is sequence-first: swap batch and token axes around the call.
        encoded = self.transformer_encoder(tokens.transpose(0, 1)).transpose(0, 1)

        normed = self.norm(encoded)
        return self.head(normed[:, 0])

In [None]:
pure_small = PureTransformerClassifier(
    img_size=50,
    patch_size=8,
    in_chans=1,
    num_classes=7,
    embed_dim=64,
    depth=3,
    num_heads=4,
    mlp_ratio=4.0,
    dropout=0.1,
).to(device)

In [None]:
train_losses, test_losses = train_model(
    pure_small,
    train_loader,
    test_loader,
    num_epochs=60,
    lr=5e-4,
    model_save_path=f"{model_path}pure_transformer_small_barsplit_post_{sigma}_{offset}",
)

In [None]:
pure_normal = PureTransformerClassifier(
    img_size=50,
    patch_size=4,
    in_chans=1,
    num_classes=7,
    embed_dim=128,
    depth=6,
    num_heads=4,
    mlp_ratio=4.0,
    dropout=0.05,
).to(device)

In [None]:
# Train the larger pure Transformer for 60 epochs on the loaded dataset.
# The checkpoint prefix carries the same "barsplit_post" tag as the dataset
# files loaded above and as the `pure_small` run, so checkpoints are labelled
# with the data they were trained on.
train_losses_pure_normal, test_losses_pure_normal = train_model(
    pure_normal,
    train_loader,
    test_loader,
    num_epochs=60,
    lr=5e-4,
    model_save_path=f"{model_path}pure_transformer_normal_barsplit_post_{sigma}_{offset}",
)

## Split watermark, post-noise $\sigma = 0.03$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(pure_small, train_loader, test_loader, train_losses, test_losses)

## Split watermark, $\sigma = 0.04$, `offset = 0.01`

In [None]:
plot_losses_and_evaluate(pure_small, train_loader, test_loader, train_losses, test_losses)

In [None]:
plot_losses_and_evaluate(
    pure_normal,
    train_loader,
    test_loader,
    train_losses_pure_normal,
    test_losses_pure_normal,
)