<a href="https://colab.research.google.com/github/julballa/LatentLattice/blob/main/escnn_autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive and install dependencies

In [102]:
from google.colab import drive
drive.mount('/content/drive')

# set current directory
# this should be the Google Drive folder where your file(s) are located
%cd /content/drive/MyDrive/lattices

## verify current directory
!ls /content/drive/MyDrive/lattices

# choose where you want your project files to be saved
project_folder = "/content/drive/MyDrive/lattices"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/lattices
12x12_train.pt	12x12_val.pt  ten_square_test.pt  ten_square_train.pt  ten_square_val.pt


In [9]:
!pip install git+https://github.com/AMLab-Amsterdam/lie_learn escnn

Collecting git+https://github.com/AMLab-Amsterdam/lie_learn
  Cloning https://github.com/AMLab-Amsterdam/lie_learn to /tmp/pip-req-build-hppoyc09
  Running command git clone --filter=blob:none --quiet https://github.com/AMLab-Amsterdam/lie_learn /tmp/pip-req-build-hppoyc09
  Resolved https://github.com/AMLab-Amsterdam/lie_learn to commit 1ccc2106e402d517a29de5438c9367c959e67338
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


# Define model

In [129]:
import torch.nn as nn
from escnn.nn import R3Conv, R3ConvTransposed, LeakyReLU, GroupPooling, GeometricTensor, FieldType
from escnn.gspaces.r3 import flipRot3dOnR3
# 2D data imports
from escnn.nn import R2Conv, R2ConvTransposed
from escnn.gspaces import rot2dOnR2

# Encoder -- 2 layer network
class Encoder(nn.Module):
    """Two-layer rotation-equivariant convolutional encoder.

    Pipeline: conv1 -> LeakyReLU -> GroupPooling -> conv2 -> LeakyReLU.
    Both convolutions are escnn ``R2Conv`` layers with kernel_size=3, stride=1.

    Args:
        in_type: escnn ``FieldType`` of the input ``GeometricTensor``.
        out_type: escnn ``FieldType`` of the latent tensor returned by ``forward``.
    """

    def __init__(self, in_type, out_type):
        super().__init__()
        self.conv1 = R2Conv(in_type, out_type, kernel_size=3, stride=1)
        self.act1 = LeakyReLU(out_type)
        self.pool = GroupPooling(out_type)
        # GroupPooling replaces every field with a single invariant channel, so
        # its output type generally differs from out_type. conv2 must therefore
        # be built for the pooled type. When out_type contains only trivial
        # fields the pooled type should equal out_type, so this matches the
        # earlier construction for that case.
        self.conv2 = R2Conv(self.pool.out_type, out_type, kernel_size=3, stride=1)
        self.act2 = LeakyReLU(out_type)

    def forward(self, x):
        """Encode a ``GeometricTensor`` of type ``in_type`` into ``out_type``."""
        x = self.act1(self.conv1(x))
        x = self.pool(x)
        x = self.act2(self.conv2(x))
        return x

# Decoder
class Decoder(nn.Module):
    """Two-layer rotation-equivariant transposed-convolution decoder.

    Pipeline: conv1 -> LeakyReLU -> conv2. No activation follows conv2.
    Both layers are escnn ``R2ConvTransposed`` with kernel_size=3, stride=1.

    Args:
        in_type: escnn ``FieldType`` of the latent ``GeometricTensor``.
        out_type: escnn ``FieldType`` of the reconstructed output.
    """

    def __init__(self, in_type, out_type):
        super().__init__()
        self.conv1 = R2ConvTransposed(in_type, out_type, kernel_size=3, stride=1)
        self.act1 = LeakyReLU(out_type)
        self.conv2 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)

    def forward(self, x):
        """Decode a latent ``GeometricTensor`` into a tensor of type ``out_type``."""
        x = self.conv1(x)
        # Apply the non-linearity between the two layers. It must be applied to
        # x, not rebuilt as a new module on every call.
        x = self.act1(x)
        x = self.conv2(x)
        return x

class AutoEncoder(nn.Module):
    """Rotation-equivariant autoencoder built from ``Encoder`` and ``Decoder``.

    ``forward`` takes and returns plain ``torch.Tensor`` objects. The input is
    wrapped as a ``GeometricTensor`` of ``in_type`` internally, and the raw
    tensor of the decoder output is returned.

    Args:
        in_type: field type of the input.
        out_type: field type produced by the decoder.
        latent_type: field type of the encoder output / decoder input.
    """

    def __init__(self, in_type, out_type, latent_type):
        super().__init__()
        # Remember the field types for wrapping inputs and for inspection.
        self.in_type, self.out_type, self.latent_type = in_type, out_type, latent_type

        self.encoder = Encoder(in_type, latent_type)
        self.decoder = Decoder(latent_type, out_type)

    def forward(self, x):
        geometric_input = GeometricTensor(x, self.in_type)
        reconstruction = self.decoder(self.encoder(geometric_input))
        return reconstruction.tensor

In [160]:
import torch.nn as nn
from escnn.nn import R2Conv, R2ConvTransposed, LeakyReLU, GroupPooling, GeometricTensor, FieldType, IIDBatchNorm2d
from escnn.gspaces import rot2dOnR2

# Encoder -- 2 layer with Batch Norm
class Encoder(nn.Module):
    """Two-layer rotation-equivariant encoder with batch normalisation.

    Pipeline: conv1 -> BN -> LeakyReLU -> GroupPooling -> conv2 -> BN -> LeakyReLU.
    Convolutions are escnn ``R2Conv`` with kernel_size=3, stride=1; normalisation
    is escnn ``IIDBatchNorm2d``.

    Args:
        in_type: escnn ``FieldType`` of the input ``GeometricTensor``.
        out_type: escnn ``FieldType`` of the latent tensor returned by ``forward``.
    """

    def __init__(self, in_type, out_type):
        super().__init__()
        self.conv1 = R2Conv(in_type, out_type, kernel_size=3, stride=1)
        self.bn1 = IIDBatchNorm2d(out_type)
        self.act1 = LeakyReLU(out_type)
        self.pool = GroupPooling(out_type)
        # GroupPooling replaces every field with one invariant channel, so conv2
        # must accept the pooled type rather than out_type. For purely trivial
        # out_type these should coincide, matching the earlier construction.
        self.conv2 = R2Conv(self.pool.out_type, out_type, kernel_size=3, stride=1)
        self.bn2 = IIDBatchNorm2d(out_type)
        self.act2 = LeakyReLU(out_type)

    def forward(self, x):
        """Encode a ``GeometricTensor`` of type ``in_type`` into ``out_type``."""
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.act2(self.bn2(self.conv2(x)))
        return x

# Decoder
class Decoder(nn.Module):
    """Two-stage equivariant decoder with batch normalisation.

    Each stage is R2ConvTransposed(kernel_size=3, stride=1) -> IIDBatchNorm2d ->
    LeakyReLU. The first stage maps ``in_type`` to ``out_type``; the second
    keeps ``out_type``.

    NOTE(review): the final stage also applies BN and LeakyReLU. If the output
    is consumed as classification logits, consider dropping them on the last
    stage.
    """

    def __init__(self, in_type, out_type):
        super().__init__()
        # stage 1: in_type -> out_type
        self.conv1 = R2ConvTransposed(in_type, out_type, kernel_size=3, stride=1)
        self.bn1 = IIDBatchNorm2d(out_type)
        self.act1 = LeakyReLU(out_type)
        # stage 2: out_type -> out_type
        self.conv2 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)
        self.bn2 = IIDBatchNorm2d(out_type)
        self.act2 = LeakyReLU(out_type)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.act1),
            (self.conv2, self.bn2, self.act2),
        )
        for conv, norm, activation in stages:
            x = activation(norm(conv(x)))
        return x

class AutoEncoder(nn.Module):
    """Equivariant autoencoder: ``Encoder`` followed by ``Decoder``.

    Args:
        in_type: field type used to wrap the raw input tensor.
        out_type: field type of the decoder output.
        latent_type: field type between encoder and decoder.

    ``forward`` accepts a plain ``torch.Tensor`` and returns the plain tensor
    underlying the decoder's ``GeometricTensor`` output.
    """

    def __init__(self, in_type, out_type, latent_type):
        super().__init__()
        self.in_type = in_type
        self.out_type = out_type
        self.latent_type = latent_type

        self.encoder = Encoder(in_type, latent_type)
        self.decoder = Decoder(latent_type, out_type)

    def forward(self, x):
        latent = self.encoder(GeometricTensor(x, self.in_type))
        return self.decoder(latent).tensor


In [150]:
# import torch.nn as nn
# from escnn.nn import R2Conv, R2ConvTransposed, LeakyReLU, GroupPooling, GeometricTensor, FieldType
# from escnn.gspaces import rot2dOnR2

# # Encoder -- 3 layer network
# class Encoder(nn.Module):
#     def __init__(self, in_type, out_type):
#         super().__init__()
#         self.conv1 = R2Conv(in_type, out_type, kernel_size=3, stride=1)
#         self.act1 = LeakyReLU(out_type)
#         self.conv2 = R2Conv(out_type, out_type, kernel_size=3, stride=1)
#         self.act2 = LeakyReLU(out_type)
#         self.conv3 = R2Conv(out_type, out_type, kernel_size=3, stride=1)  # Additional layer
#         self.act3 = LeakyReLU(out_type)
#         self.pool = GroupPooling(out_type)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.act1(x)
#         x = self.pool(x)
#         x = self.conv2(x)
#         x = self.act2(x)
#         x = self.pool(x)
#         x = self.conv3(x)
#         x = self.act3(x)

#         return x

# # Decoder
# class Decoder(nn.Module):
#     def __init__(self, in_type, out_type):
#         super().__init__()
#         self.conv1 = R2ConvTransposed(in_type, out_type, kernel_size=3, stride=1)
#         self.act1 = LeakyReLU(out_type)
#         self.conv2 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)
#         self.act2 = LeakyReLU(out_type)  # Additional activation
#         self.conv3 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)  # Additional layer


#     def forward(self, x):
#         x = self.conv1(x)
#         self.act1 = LeakyReLU(out_type)
#         x = self.conv2(x)
#         self.act2 = LeakyReLU(out_type)
#         x = self.conv3(x) # Additional layer processing
#         return x

# class AutoEncoder(nn.Module):
#     def __init__(self, in_type, out_type, latent_type):
#         super().__init__()
#         self.in_type = in_type
#         self.out_type = out_type
#         self.latent_type = latent_type

#         self.encoder = Encoder(in_type, latent_type)
#         self.decoder = Decoder(latent_type, out_type)

#     def forward(self, x):
#         x = GeometricTensor(x, self.in_type)
#         z = self.encoder(x)
#         x_hat = self.decoder(z)
#         return x_hat.tensor


In [156]:
# import torch.nn as nn
# from escnn.nn import R2Conv, R2ConvTransposed, LeakyReLU, GroupPooling, GeometricTensor, FieldType, IIDBatchNorm2d
# from escnn.gspaces import rot2dOnR2

# # Encoder -- 3 layer network with BatchNorm
# class Encoder(nn.Module):
#     def __init__(self, in_type, out_type):
#         super().__init__()
#         self.conv1 = R2Conv(in_type, out_type, kernel_size=3, stride=1)
#         self.bn1 = IIDBatchNorm2d(out_type)
#         self.act1 = LeakyReLU(out_type)
#         self.conv2 = R2Conv(out_type, out_type, kernel_size=3, stride=1)
#         self.bn2 = IIDBatchNorm2d(out_type)
#         self.act2 = LeakyReLU(out_type)
#         self.conv3 = R2Conv(out_type, out_type, kernel_size=3, stride=1)
#         self.bn3 = IIDBatchNorm2d(out_type)
#         self.act3 = LeakyReLU(out_type)
#         self.pool = GroupPooling(out_type)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.act1(x)
#         x = self.conv2(x)
#         x = self.bn2(x)
#         x = self.act2(x)
#         x = self.conv3(x)
#         x = self.bn3(x)
#         x = self.act3(x)
#         x = self.pool(x)
#         return x

# # Decoder
# class Decoder(nn.Module):
#     def __init__(self, in_type, out_type):
#         super().__init__()
#         self.conv1 = R2ConvTransposed(in_type, out_type, kernel_size=3, stride=1)
#         self.bn1 = IIDBatchNorm2d(out_type)
#         self.act1 = LeakyReLU(out_type)
#         self.conv2 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)
#         self.bn2 = IIDBatchNorm2d(out_type)
#         self.act2 = LeakyReLU(out_type)
#         self.conv3 = R2ConvTransposed(out_type, out_type, kernel_size=3, stride=1)
#         self.bn3 = IIDBatchNorm2d(out_type)
#         self.act3 = LeakyReLU(out_type)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.act1(x)
#         x = self.conv2(x)
#         x = self.bn2(x)
#         x = self.act2(x)
#         x = self.conv3(x)
#         x = self.bn3(x)
#         x = self.act3(x)
#         return x

# class AutoEncoder(nn.Module):
#     def __init__(self, in_type, out_type, latent_type):
#         super().__init__()
#         self.in_type = in_type
#         self.out_type = out_type
#         self.latent_type = latent_type

#         self.encoder = Encoder(in_type, latent_type)
#         self.decoder = Decoder(latent_type, out_type)

#     def forward(self, x):
#         x = GeometricTensor(x, self.in_type)
#         z = self.encoder(x)
#         x_hat = self.decoder(z)
#         return x_hat.tensor


# Load dataset

In [22]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.1 MB[0m [31m5.4 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━[0m [32m0.6/1.1 MB[0m [31m9.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m10.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3


In [116]:
from torch.utils.data import Dataset, DataLoader
import torch

class LatticeDataset(Dataset):
    """Map-style dataset turning lattice graphs into single-channel square images.

    Each element of ``data_list`` must expose an ``x`` attribute holding one
    value per node, with ``side * side`` values in total. These are presumably
    ``torch_geometric`` ``Data`` objects; TODO confirm. The values are
    reshaped in row-major order into a float tensor of shape
    ``(1, side, side)``. This assumes nodes are stored row by row; verify
    against the data generator.

    Because the dataset feeds an autoencoder, ``__getitem__`` returns the same
    tensor as both input and target.

    Args:
        data_list: sequence of graph objects with an ``x`` attribute.
        side: side length of the square grid (default 12 for the 12x12 lattices).
    """

    def __init__(self, data_list, side=12):
        self.data_list = data_list
        self.side = side

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        # reshape (not view) so that non-contiguous x is accepted too. It still
        # raises if x does not contain exactly side*side elements.
        node_features = data.x.reshape(1, self.side, self.side).float()
        return node_features, node_features

# The .pt files appear to hold pickled graph objects (accessed via `.x` above),
# not plain tensors. They therefore need weights_only=False, because the
# default became True in PyTorch 2.6. This unpickles arbitrary Python objects,
# so only load files you trust.
# Paths are relative to the working directory set by the Drive cell (%cd).
train_data = torch.load('12x12_train.pt', weights_only=False)
val_data = torch.load('12x12_val.pt', weights_only=False)
# test_data = torch.load('ten_square_test.pt', weights_only=False)

train_dataset = LatticeDataset(train_data)
val_dataset = LatticeDataset(val_data)
# test_dataset = LatticeDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



# Train model

In [164]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Seed the global torch RNG so weight initialisation and the shuffling done by
# train_loader are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on device: {device}")

# Init model. Every field uses the trivial representation of one shared
# rotation gspace. out_type has two channels, which CrossEntropyLoss treats as
# one logit per class. Targets must therefore be class indices in {0, 1};
# TODO confirm the label set of the data.
r2_act = rot2dOnR2()
in_type = FieldType(r2_act, [r2_act.trivial_repr])
latent_type = FieldType(r2_act, [r2_act.trivial_repr])
out_type = FieldType(r2_act, [r2_act.trivial_repr] * 2)

model = AutoEncoder(in_type, out_type, latent_type)

model = model.to(device)
criterion = nn.CrossEntropyLoss()  # per-pixel classification over the out_type channels
optimizer = optim.Adam(model.parameters(), lr=0.05)

# Multiply the learning rate by 0.1 every 40 epochs.
# NOTE(review): after 200 epochs the LR is 5e-7, so most of the 500 epochs
# barely update the weights. Consider fewer epochs or a larger step_size.
scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

num_epochs = 500
for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    train_loss = 0.0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        # (B, 1, H, W) float -> (B, H, W) long class indices. squeeze(1) removes
        # only the channel axis; a bare squeeze() would also drop the batch axis
        # when the last batch holds a single sample, which breaks the loss.
        targets = targets.to(device).squeeze(1).long()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch value is a per-sample mean.
        train_loss += loss.item() * inputs.size(0)

    train_loss /= len(train_loader.dataset)

    scheduler.step()

    # ---- validation ----
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device).squeeze(1).long()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


Training on device: cpu
Epoch 1/500, Train Loss: 0.5114, Validation Loss: 1.5878
Epoch 2/500, Train Loss: 0.4323, Validation Loss: 0.4230
Epoch 3/500, Train Loss: 0.4016, Validation Loss: 0.6046
Epoch 4/500, Train Loss: 0.3862, Validation Loss: 0.4734
Epoch 5/500, Train Loss: 0.3822, Validation Loss: 0.3928
Epoch 6/500, Train Loss: 0.3800, Validation Loss: 0.3958
Epoch 7/500, Train Loss: 0.3788, Validation Loss: 0.3809
Epoch 8/500, Train Loss: 0.3774, Validation Loss: 0.3860
Epoch 9/500, Train Loss: 0.3763, Validation Loss: 0.3799
Epoch 10/500, Train Loss: 0.3755, Validation Loss: 0.3911
Epoch 11/500, Train Loss: 0.3748, Validation Loss: 0.3759
Epoch 12/500, Train Loss: 0.3740, Validation Loss: 0.3739
Epoch 13/500, Train Loss: 0.3734, Validation Loss: 0.4031
Epoch 14/500, Train Loss: 0.3728, Validation Loss: 0.3780
Epoch 15/500, Train Loss: 0.3723, Validation Loss: 0.3717
Epoch 16/500, Train Loss: 0.3717, Validation Loss: 0.3793
Epoch 17/500, Train Loss: 0.3714, Validation Loss: 0.3764

KeyboardInterrupt: 