# Install dependencies

In [1]:
from google.colab import drive
drive.mount('/content/drive')

# set current directory
# this should be the Google Drive folder where your file(s) are located
%cd /content/drive/MyDrive/lattices

## verify current directory
!ls /content/drive/MyDrive/lattices

# choose where you want your project files to be saved
project_folder = "/content/drive/MyDrive/lattices"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/lattices
20240504_2031  20240505_0042  20240505_0120  20240505_0148  20240505_0244  20240505_0249
20240505_0005  20240505_0045  20240505_0135  20240505_0214  20240505_0245  data
20240505_0013  20240505_0049  20240505_0136  20240505_0217  20240505_0246  LatentLattice
20240505_0040  20240505_0114  20240505_0137  20240505_0218  20240505_0247
20240505_0041  20240505_0119  20240505_0140  20240505_0233  20240505_0248


In [2]:
!pip install git+https://github.com/AMLab-Amsterdam/lie_learn escnn

Collecting git+https://github.com/AMLab-Amsterdam/lie_learn
  Cloning https://github.com/AMLab-Amsterdam/lie_learn to /tmp/pip-req-build-19rcal8p
  Running command git clone --filter=blob:none --quiet https://github.com/AMLab-Amsterdam/lie_learn /tmp/pip-req-build-19rcal8p
  Resolved https://github.com/AMLab-Amsterdam/lie_learn to commit 1ccc2106e402d517a29de5438c9367c959e67338
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting escnn
  Downloading escnn-1.0.11-py3-none-any.whl (373 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m373.9/373.9 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting pymanopt (from escnn)
  Downloading pymanopt-2.2.0-py3-none-any.whl (71 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.8/71.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:0

# Define model

In [95]:
import torch.nn as nn
from escnn.nn import R3Conv, R3ConvTransposed, LeakyReLU, GroupPooling, GeometricTensor, FieldType
from escnn.gspaces.r3 import flipRot3dOnR3

# Encoder
class Encoder(nn.Module):
    def __init__(self, in_type, out_type):
        super().__init__()
        self.conv1 = R3Conv(in_type, out_type, kernel_size=3, stride=1)
        self.act1 = LeakyReLU(out_type)
        self.conv2 = R3Conv(out_type, out_type, kernel_size=3, stride=1)
        self.act2 = LeakyReLU(out_type)
        self.pool = GroupPooling(out_type)


    def forward(self, x):
        x = self.conv1(x)
        x = self.act1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.act2(x)
        return x

# Decoder
class Decoder(nn.Module):
    def __init__(self, in_type, out_type):
        super().__init__()
        self.conv1 = R3ConvTransposed(in_type, out_type, kernel_size=3, stride=1)
        self.act1 = LeakyReLU(out_type)
        self.conv2 = R3ConvTransposed(out_type, out_type, kernel_size=3, stride=1)

    def forward(self, x):
        x = self.conv1(x)
        self.act1 = LeakyReLU(out_type)
        x = self.conv2(x)
        return x

class AutoEncoder(nn.Module):
    def __init__(self, in_type, out_type, latent_type):
        super().__init__()
        self.in_type = in_type
        self.out_type = out_type
        self.latent_type = latent_type

        self.encoder = Encoder(in_type, latent_type)
        self.decoder = Decoder(latent_type, out_type)

    def forward(self, x):
        x = GeometricTensor(x, self.in_type)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat.tensor

# Load dataset

In [96]:
import torch
from torch.utils.data import Dataset, DataLoader

class LatticeDataset(Dataset):
    def __init__(self, shape, length, num_node_types=2):
        self.length = length
        self.data = [torch.randint(0, num_node_types, shape).float() for _ in range(length)]  # 1 channel, size x size lattices

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.data[idx], self.data[idx]

shape = (1,8,8,8)
train_dataset = LatticeDataset(shape=shape, length=1000, num_node_types=2)
val_dataset = LatticeDataset(shape=shape, length=200, num_node_types=2)
test_dataset = LatticeDataset(shape=shape, length=200, num_node_types=2)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Train model

In [100]:
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on device: {device}")

# Init model
in_type = FieldType(flipRot3dOnR3(), [flipRot3dOnR3().trivial_repr])
latent_type = FieldType(flipRot3dOnR3(), [flipRot3dOnR3().trivial_repr])
out_type = FieldType(flipRot3dOnR3(), [flipRot3dOnR3().trivial_repr]*2)

model = AutoEncoder(in_type, out_type, latent_type)

model = model.to(device)
criterion = nn.CrossEntropyLoss()  # Loss function
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Optimizer

num_epochs = 500
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    train_loss = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device).squeeze().long()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation phase
    val_loss = 0
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device).squeeze().long()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


Training on device: cpu
Epoch 1/500, Train Loss: 0.6931, Validation Loss: 0.6931
Epoch 2/500, Train Loss: 0.6930, Validation Loss: 0.6929
Epoch 3/500, Train Loss: 0.6929, Validation Loss: 0.6926
Epoch 4/500, Train Loss: 0.6921, Validation Loss: 0.6920
Epoch 5/500, Train Loss: 0.6900, Validation Loss: 0.6890
Epoch 6/500, Train Loss: 0.6883, Validation Loss: 0.6866
Epoch 7/500, Train Loss: 0.6861, Validation Loss: 0.6849
Epoch 8/500, Train Loss: 0.6851, Validation Loss: 0.6840
Epoch 9/500, Train Loss: 0.6841, Validation Loss: 0.6823
Epoch 10/500, Train Loss: 0.6820, Validation Loss: 0.6791
Epoch 11/500, Train Loss: 0.6767, Validation Loss: 0.6701
Epoch 12/500, Train Loss: 0.6621, Validation Loss: 0.6484
Epoch 13/500, Train Loss: 0.6394, Validation Loss: 0.6313
Epoch 14/500, Train Loss: 0.6296, Validation Loss: 0.6279
Epoch 15/500, Train Loss: 0.6276, Validation Loss: 0.6269
Epoch 16/500, Train Loss: 0.6267, Validation Loss: 0.6262
Epoch 17/500, Train Loss: 0.6261, Validation Loss: 0.6257

KeyboardInterrupt: 