In [None]:
import torch_geometric
import torch

In [None]:
import tqdm
import os
import random

In [None]:
# Download the dataset archive by its Google Drive file ID (needs the `gdown` CLI on PATH).
# NOTE(review): the archive is presumably saved as airfoils_dataset.zip, the name the unzip cell expects — confirm.
!gdown 1mTGV-I_oIKR0WWn22tkofmIhIEdD1q_m

In [None]:
# -n: never overwrite files that are already extracted, so re-running the
#     notebook does not stall on unzip's interactive "replace?" prompt.
# -q: keep the cell output short instead of listing every extracted file.
!unzip -n -q airfoils_dataset.zip

In [None]:
import torch_geometric
from torch_geometric.data import Data, DataLoader

import glob

import copy

import numpy as np

import os
import json 

import torch
from torch_geometric.data import Dataset, download_url, Batch, DataLoader
from torch_geometric.transforms import FaceToEdge

from torch_geometric.transforms import Cartesian, GenerateMeshNormals

class AirfoilsDataset(Dataset):
    """Airfoil outlines loaded eagerly from the JSON files in ``<dir_path>/<split>/``.

    Every file must provide the keys ``"coords"`` (node coordinates), ``"CL"``
    and ``"CD"``. Each file becomes one closed-loop graph: node ``i`` is linked
    to node ``i + 1`` and the last node back to the first. The regression
    target is ``y = -CL / CD``.

    Args:
        dir_path: Root directory of the dataset.
        split: Sub-directory to read.
        samples_n: Accepted for interface compatibility; currently unused.
        noise: If True, ``get`` appends one extra input channel to ``x``
            (see ``get``).
        zeros: Only meaningful with ``noise=True``; if True the extra channel
            always carries the constant placeholder, never the target.
    """

    def __init__(self, dir_path, split="train", samples_n=None, noise=False, zeros=False):

        super(AirfoilsDataset, self).__init__()

        self.noise = noise
        self.zeros = zeros

        self.cartesian_coords = Cartesian()
        self.normals = GenerateMeshNormals()

        self.dir_path = dir_path
        self.split = split
        fld_dir = os.path.join(dir_path, split, "*")
        print(fld_dir)
        # glob returns files in filesystem order, which differs between
        # machines; sorting keeps sample order (and therefore the unshuffled
        # training batches) reproducible.
        self.filenames = sorted(glob.glob(fld_dir))
        self.samples = [self.generate_sample(f) for f in self.filenames]

    def generate_sample(self, filename):
        """Parse one JSON file into a ``Data`` graph with Cartesian edge attributes."""
        # Context manager so the file handle is closed deterministically.
        with open(filename) as sample_file:
            sample = json.load(sample_file)

        pos = torch.tensor(sample["coords"])
        n_nodes = len(pos)

        # Closed polyline: i -> (i + 1) mod n, stored as a (2, n) index tensor.
        source = torch.arange(n_nodes)
        edge_index = torch.stack([source, (source + 1) % n_nodes])

        # Target is the negated lift-to-drag ratio.
        target = -sample["CL"] / sample["CD"]

        graph = Data(x=pos, pos=pos, edge_index=edge_index, y=target)
        # Reuse the transform built in __init__ instead of creating a new one per sample.
        graph = self.cartesian_coords(graph)

        return graph

    def len(self):
        """Number of samples (one per file found for the split)."""
        return len(self.filenames)

    def get(self, idx):
        """Return sample ``idx``; with ``noise=True`` add the feedback channel.

        The extra channel is constant over the nodes of the graph. It equals
        ``y / 500`` when ``eps`` is 1 and ``100 / 500`` otherwise, where
        ``eps`` is drawn from ``[0, 0, 0, 1]`` and forced to 0 if ``zeros``.
        """
        # Work on a copy so the extra channel is not written into the cached sample.
        graph = copy.copy(self.samples[idx])

        if self.noise:
            N = 500  # scale applied to the feedback channel
            vertices_number = graph.x.shape[0]
            # The draw happens even when `zeros` is set, so the `random`
            # stream is consumed identically in both modes.
            eps = random.choice([0, 0, 0, 1]) * float(not self.zeros)
            y = torch.tensor(graph.y).tile([vertices_number, 1]) * eps + 100 * (1 - eps)
            graph.x = torch.cat([graph.x, y / N], dim=1)

        return graph

In [None]:
path = "airfoils_dataset"

# Plain variants feed the 2-channel network; the `noise` variants add the
# extra feedback channel expected by the 3-channel (ZigZag) network.
# Training keeps the random target/placeholder mix, while the test and OOD
# variants pass zeros=True so the channel always holds the placeholder.
dataset = AirfoilsDataset(dir_path=path, split="train")
noise_dataset = AirfoilsDataset(dir_path=path, split="train", noise=True)

test_dataset = AirfoilsDataset(dir_path=path, split="test")
noise_test_dataset = AirfoilsDataset(dir_path=path, split="test", noise=True, zeros=True)

ood_dataset = AirfoilsDataset(dir_path=path, split="ood")
noise_ood_dataset = AirfoilsDataset(dir_path=path, split="ood", noise=True, zeros=True)

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GMMConv

from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairOptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops


class PointNetConv(MessagePassing):
    """Message-passing layer whose messages combine neighbour features and relative positions.

    For an edge j -> i the message is ``local_nn([x_j, pos_j - pos_i])`` (the
    feature part is skipped when ``x_j`` is None, the network when ``local_nn``
    is None). Messages are aggregated with ``max`` unless another ``aggr`` is
    passed through ``kwargs``; ``global_nn`` (if given) is applied to the result.

    Not referenced by the other cells visible in this notebook.
    """

    def __init__(self, local_nn: Optional[Callable] = None,
                 global_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        # Default to max aggregation, but let callers override it.
        kwargs.setdefault('aggr', 'max')
        super().__init__(**kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn
        self.add_self_loops = add_self_loops

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of both wrapped networks."""
        super().reset_parameters()
        reset(self.local_nn)
        reset(self.global_nn)

    def forward(
        self,
        x: Union[OptTensor, PairOptTensor],
        pos: Union[Tensor, PairTensor],
        edge_index: Adj,
    ) -> Tensor:
        """Propagate along ``edge_index`` and return the aggregated node features."""

        # Normalise inputs to (source, target) pairs.
        if not isinstance(x, tuple):
            x = (x, None)

        if isinstance(pos, Tensor):
            pos = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # Drop existing self loops first so each node gets exactly one.
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
            elif isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.set_diag(edge_index)

        # propagate_type: (x: PairOptTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos)

        if self.global_nn is not None:
            out = self.global_nn(out)

        return out

    def message(self, x_j: Optional[Tensor], pos_i: Tensor, pos_j: Tensor) -> Tensor:
        """Build the per-edge message from neighbour features and the position offset."""
        msg = pos_j - pos_i
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '
                f'global_nn={self.global_nn})')


import torch_geometric.nn
from torch.nn import Linear, ReLU, LeakyReLU, ELU, Dropout

from torch_geometric.nn import Sequential, GCNConv, GMMConv

# Activation *class* (not an instance): `block` and `Net` call it to create a fresh layer each time.
activation = torch.nn.ReLU

def block(n_channels=12):
    """Build a stack of five GMMConv layers, each followed by an in-place activation.

    Channel widths go n -> n -> 2n -> 2n -> 2n -> n, so the block maps
    ``n_channels`` features back to ``n_channels`` features. Every GMMConv is
    created with the positional arguments ``(in, out, 2, 2)``.
    """
    wide = 2 * n_channels
    conv_signature = 'x, edge_index, edge_attr -> x'
    channel_plan = (
        (n_channels, n_channels),
        (n_channels, wide),
        (wide, wide),
        (wide, wide),
        (wide, n_channels),
    )

    layers = []
    for c_in, c_out in channel_plan:
        # Same creation order as writing the layers out by hand.
        layers.append((GMMConv(c_in, c_out, 2, 2), conv_signature))
        layers.append(activation(inplace=True))

    return Sequential('x, edge_index, edge_attr', layers)

class Net(torch.nn.Module):
    """Graph regressor: GMMConv stem, five conv blocks, mean pooling, 4-layer MLP head.

    Args:
        channels: Hidden width used by every layer.
        dropout: Dropout probability applied after the 2nd and 3rd head layers.
        input_channels: Number of node-feature channels in ``data.x``.
    """

    def __init__(self, channels=64, dropout=0., input_channels=2):
        super().__init__()

        self.dropout = dropout
        width = channels

        self.conv1 = GMMConv(input_channels, width, 2, 2)

        # Registered as block1 ... block5, in that order, so parameter names
        # and initialisation order match an explicit attribute-by-attribute list.
        for idx in range(1, 6):
            setattr(self, f"block{idx}", block(width))

        self.conv_final = GMMConv(width, width, 2, 2)

        self.fc1 = torch.nn.Linear(width, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.fc3 = torch.nn.Linear(width, width)
        self.fc4 = torch.nn.Linear(width, 1)
        self.activation = activation()

    def forward(self, data, features=False):
        """Return one prediction per graph, shape ``(num_graphs, 1)``.

        With ``features=True`` the pooled graph embedding is returned as a
        second value.
        """
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        x = F.elu(self.conv1(data.x, edge_index, edge_attr))

        for name in ("block1", "block2", "block3", "block4", "block5"):
            x = getattr(self, name)(x, edge_index, edge_attr)

        x = self.conv_final(x, edge_index, edge_attr)
        feat = global_mean_pool(x, batch)

        x = self.activation(self.fc1(feat))

        # training=True is passed unconditionally, so dropout stays active in
        # eval mode too (a no-op while p == 0).
        for head in (self.fc2, self.fc3):
            x = torch.nn.functional.dropout(self.activation(head(x)), p=self.dropout, training=True)

        prediction = self.fc4(x)
        if features:
            return prediction, feat
        return prediction

In [None]:
import os
# Directory the training cells save their checkpoints to; exist_ok keeps the cell re-runnable.
os.makedirs("weights", exist_ok=True)

# Train Vanilla

In [None]:
import tqdm
import random

# Train the plain 2-channel network once per seed and save its weights.
for r in [16]:

    lr = 1e-3
    epochs = 10
    device = "cuda"
    batch_size = 128

    random_seed = r
    torch.backends.cudnn.enabled = False
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    model = Net(64).to(device)
    # NOTE(review): the training loader is not shuffled; kept as-is so that
    # results stay comparable with existing checkpoints.
    dataloader = DataLoader(dataset, batch_size)
    # The evaluation loaders are created here and used by the "Eval" cells.
    test_loader = DataLoader(test_dataset, 1)
    ood_loader = DataLoader(ood_dataset, 1)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    # Two phases of `epochs` epochs each, sharing one optimizer: first at
    # `lr`, then at 1e-4.
    for phase_lr in (lr, 1e-4):

        for g in optimizer.param_groups:
            g['lr'] = phase_lr

        for epoch in range(epochs):

            res = []  # per-batch MAE, averaged for the progress bar
            t = tqdm.tqdm(dataloader, desc='Current Loss = ', leave=True)

            for batch in t:
                batch = batch.to(device)
                optimizer.zero_grad()
                y = model(batch)

                loss = F.mse_loss(y, batch.y[:, None])

                res.append(torch.abs(y[:, 0].detach() - batch.y).mean().item())

                loss.backward()
                optimizer.step()

                t.set_description(f"Current Loss = {sum(res) / len(res)}", refresh=True)

    torch.save(model.state_dict(), f"./weights/single_model_{random_seed}.pth")

# ZigZag Training

In [None]:
import tqdm
import os
import random

# Train the 3-channel (ZigZag) network on the dataset variant that carries
# the extra feedback channel, once per seed, and save its weights.
for r in range(16, 17):

    lr = 1e-3
    epochs = 10
    device = "cuda"
    batch_size = 128

    random_seed = r
    torch.use_deterministic_algorithms(False, warn_only=True)
    torch.backends.cudnn.enabled = False
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    # NOTE(review): NumPy is seeded with a constant 0 rather than
    # `random_seed`; left unchanged — confirm whether that is intended.
    np.random.seed(0)

    model = Net(64, 0.0, 3).to(device)
    dataloader = DataLoader(noise_dataset, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    # Two phases of `epochs` epochs each, sharing one optimizer: first at
    # `lr`, then at 1e-4.
    for phase_lr in (lr, 1e-4):

        for g in optimizer.param_groups:
            g['lr'] = phase_lr

        for epoch in range(epochs):

            res = []  # per-batch MAE, averaged for the progress bar
            # Progress total comes from the loader actually being iterated.
            t = tqdm.tqdm(dataloader, desc='Current Loss = ', leave=True)

            for batch in t:
                batch = batch.to(device)
                optimizer.zero_grad()
                y = model(batch)

                loss = F.mse_loss(y, batch.y[:, None])

                # The displayed MAE skips samples whose target is exactly 0.
                mask = batch.y != 0
                res.append(torch.abs(y[:, 0][mask].detach() - batch.y[mask]).mean().item())

                loss.backward()
                optimizer.step()

                t.set_description(f"Current Loss = {sum(res) / len(res)}", refresh=True)

    torch.save(model.state_dict(), f"./weights/zigzag_{random_seed}.pth")

# Eval

In [None]:
import tqdm
import time 

def test(model, loader):
    """Run ``model`` over ``loader`` and return the absolute error of every sample.

    Args:
        model: Callable mapping a batch to one prediction per graph.
        loader: Iterable of batches exposing ``.cuda()`` and ``.y``; must
            support ``len``.

    Returns:
        list[float]: one absolute error per sample, in loader order.
    """
    t = tqdm.trange(len(loader), desc='Current Loss = ', leave=True)

    errors = []
    for _, batch in zip(t, loader):
        batch = batch.cuda()
        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            pred_y = model(batch)
        # Flatten both sides so (B, 1) predictions line up with (B,) targets
        # instead of broadcasting to a B x B matrix; this makes the function
        # correct for any batch size, not just 1.
        abs_err = torch.abs(pred_y.reshape(-1) - batch.y.reshape(-1))
        errors.extend(abs_err.cpu().tolist())
        t.set_description(f"Current Loss = {sum(errors) / len(errors)}", refresh=True)

    return errors

## Vanilla

In [None]:
# Rebuild the 2-channel network in eval mode, then restore the weights saved
# by the vanilla training cell (seed 16).
network = Net(channels=64).cuda().eval()
network.load_state_dict(torch.load(f"./weights/single_model_16.pth"))

In [None]:
# Build the loader here instead of relying on the `test_loader` global: that
# name only exists after the training cell has run and is reassigned to an
# OOD loader by the ActMAD evaluation cell.
errors = test(network, DataLoader(test_dataset, 1))
print("TEST SET MAE: ", np.mean(errors))

In [None]:
# Build the loader here so the cell also works when the weights were loaded
# from disk without re-running the training cell that defines `ood_loader`.
errors = test(network, DataLoader(ood_dataset, 1))
print("OOD SET MAE: ", np.mean(errors))

## ZigZag

In [None]:
# Rebuild the 3-channel network in eval mode, then restore the weights saved
# by the ZigZag training cell (seed 16).
znetwork = Net(channels=64, dropout=0.0, input_channels=3).cuda().eval()
znetwork.load_state_dict(torch.load(f"./weights/zigzag_16.pth"))

In [None]:
# One graph per batch over the test split that carries the placeholder feedback channel.
noise_test_loader = DataLoader(noise_test_dataset, batch_size=1)

In [None]:
# Per-sample absolute errors of the ZigZag network on the test split.
errors = test(model=znetwork, loader=noise_test_loader)
print("TEST SET MAE: ", np.mean(errors))

In [None]:
# Per-sample absolute errors of the ZigZag network on the OOD split.
errors = test(model=znetwork, loader=DataLoader(noise_ood_dataset, batch_size=1))
print("OOD SET MAE: ", np.mean(errors))

# ITTT Optimization

In [None]:
from copy import deepcopy
import torch.optim as optim
import time

def noise_uncertainty(model, model_original, data, scale=500.0):
    """Two-pass consistency score between a model and a reference model.

    Pass 1 runs ``model`` on ``data``. Its per-graph prediction, divided by
    ``scale``, is written into feature column 2 of every node of that graph
    in a deep copy of ``data``. Pass 2 runs ``model_original`` on the copy.

    Args:
        model: Network used for the first pass.
        model_original: Network used for the second pass.
        data: Batch exposing ``.cuda()``, ``.x`` with at least 3 columns and
            ``.batch`` (node-to-graph index).
        scale: Divisor applied to the first-pass prediction before it is fed
            back. Defaults to 500, the same value ``AirfoilsDataset.get``
            divides the feedback channel by.

    Returns:
        tuple: ``(mean |y_1 - y_2|, y_2, y_1)`` — note the order: the
        second-pass prediction comes before the first-pass one.
    """
    y_1 = model(data.cuda())
    # Deep copy so the feedback channel is not written into `data.x`.
    copy_data = deepcopy(data)
    batch_mask = data.batch

    for i in range(len(y_1)):
        copy_data.x[batch_mask == i, 2] = y_1[i] / scale

    y_2 = model_original(copy_data.cuda())
    return (y_1 - y_2).abs().mean(), y_2, y_1

def js_div(p, q):
    """Jensen-Shannon divergence between two batches of probability vectors.

    ``JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)`` with ``m = (p + q) / 2``,
    summed over the distribution axis and averaged over the batch
    (``reduction='batchmean'``).

    Args:
        p, q: Tensors of probabilities with identical shape, batch first.

    Returns:
        Scalar tensor in ``[0, ln 2]``.
    """
    m = 0.5 * (p + q)
    # F.kl_div(input, target) evaluates KL(target || exp(input)), so the
    # mixture must be the log-space *input* and p / q the targets.
    # Clamping keeps log(m) finite where both p and q are zero; those entries
    # contribute 0 because the target there is 0 as well.
    log_m = torch.log(m.clamp_min(1e-12))
    return 0.5 * (F.kl_div(log_m, p, reduction='batchmean') +
                  F.kl_div(log_m, q, reduction='batchmean'))

def ttt_one_instance(x, f_ttt, f, optimizer, n_steps, n_classes=10):
    """Adapt ``f_ttt`` to one batch for ``n_steps`` steps, then predict.

    ``f_ttt`` is first reset to the weights of ``f``. Each step minimises the
    consistency score returned by ``noise_uncertainty(f_ttt, f, x)``.

    Args:
        x: Batch to adapt on and predict for.
        f_ttt: Working copy of the model; its weights are overwritten.
        f: Reference model; put into eval mode and never stepped.
        optimizer: Optimizer over ``f_ttt``'s parameters. Its internal state
            is not reset here.
        n_steps: Number of gradient steps.
        n_classes: Unused; kept for interface compatibility.

    Returns:
        tuple: the second and third values of ``noise_uncertainty`` in that
        order, i.e. ``(second-pass prediction by f, first-pass prediction by
        the adapted f_ttt)``.
    """
    f_ttt.load_state_dict(f.state_dict())  # reset f_ttt to f
    f_ttt.train()
    f.eval()
    for step in range(n_steps):
        loss, _, _ = noise_uncertainty(f_ttt, f, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    f_ttt.eval()
    # The final pass is only read out, so no autograd graph is needed.
    with torch.no_grad():
        _, second_pass, first_pass = noise_uncertainty(f_ttt, f, x)
    return second_pass, first_pass


def ttt(f, test_loader, n_steps, lr):
    """Evaluate ``f`` with per-batch test-time adaptation.

    A deep copy of ``f`` is adapted on every batch by ``ttt_one_instance``
    (which resets it to ``f``'s weights first). One Adam optimizer is shared
    across all batches.

    Args:
        f: Trained model; left in eval mode.
        test_loader: Loader of batches; must support ``len``.
        n_steps: Gradient steps per batch.
        lr: Learning rate of the adaptation optimizer.

    Returns:
        tuple[np.ndarray, np.ndarray]: per-sample absolute errors and the
        matching targets, in loader order.
    """
    f_ttt = deepcopy(f)
    f.eval()
    optimizer = optim.Adam(f_ttt.parameters(), lr=lr)

    t = tqdm.trange(len(test_loader), desc='Current Loss = ', leave=True)
    errors = []
    targets = []
    for _, batch in zip(t, test_loader):
        # The error is measured on the second value returned by
        # ttt_one_instance (the adapted model's first-pass prediction).
        y_hat_1, y_hat_2 = ttt_one_instance(batch, f_ttt, f, optimizer, n_steps)
        mae = torch.abs(y_hat_2[:, 0].detach().cpu() - batch.y.cpu()).numpy().ravel().tolist()
        targets += batch.y.ravel().tolist()
        errors += mae
        t.set_description(f"Current Loss = {sum(errors) / len(errors)}", refresh=True)

    return np.array(errors), np.array(targets)

def vanilla_eval(f, test_loader):
    """Evaluate ``f`` without adaptation.

    Args:
        f: Model returning predictions of shape ``(batch, 1)``; put into
            eval mode.
        test_loader: Loader of batches exposing ``.cuda()`` and ``.y``; must
            support ``len``.

    Returns:
        tuple[np.ndarray, np.ndarray]: per-sample absolute errors and the
        matching targets, in loader order.
    """
    f.eval()
    t = tqdm.trange(len(test_loader), desc='Current Loss = ', leave=True)
    errors = []
    targets = []
    for _, batch in zip(t, test_loader):
        batch = batch.cuda()
        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            y_hat = f(batch)
        mae = torch.abs(y_hat[:, 0].detach().cpu() - batch.y.cpu()).numpy().ravel().tolist()
        targets += batch.y.ravel().tolist()
        errors += mae
        t.set_description(f"Current Loss = {sum(errors) / len(errors)}", refresh=True)

    return np.array(errors), np.array(targets)

# Vanilla OOD Eval

In [None]:
# Build the loader here so the cell does not depend on the `ood_loader`
# global, which only exists after the training cell has run.
errors, targets = vanilla_eval(network, DataLoader(ood_dataset, 1))

# Split the per-sample errors into five bins by target quintile. Both bin
# edges are inclusive, so a sample lying exactly on an edge lands in two bins.
errs = []
for i in range(5):
    q1, q2 = np.quantile(targets, q=0.2 * i), np.quantile(targets, q=0.2 * (i + 1))
    mask = np.logical_and(targets >= q1, targets <= q2)
    err = errors[mask].tolist()
    errs.append(err)

# Flatten from the highest-target bin downwards and drop the lowest bin.
NOT_OPTIMIZED = [s for t in errs[::-1][:-1] for s in t]

# ITTT OOD Eval

In [None]:
# Per-sample OOD errors of the test-time-adapted ZigZag network, one entry per batch size.
ittt_results = {}
for bs in [2, 4, 16]:

    errors, targets = ttt(znetwork, DataLoader(noise_ood_dataset, bs), n_steps=1, lr=2e-4) # PROPER SETUP FROM PAPER

    # Five bins by target quintile; both edges are inclusive.
    errs = []
    for i in range(5):
        q1 = np.quantile(targets, q=0.2 * i)
        q2 = np.quantile(targets, q=0.2 * (i + 1))
        mask = (targets >= q1) & (targets <= q2)
        err = errors[mask].tolist()
        errs.append(err)

    # Bins 4, 3, 2, 1 (highest targets first); bin 0 is left out.
    ittt_results[bs] = [value for level in errs[:0:-1] for value in level]

# ActMAD OOD Eval

In [None]:
from torch_ttt.engine.actmad_engine import ActMADEngine

In [None]:
# Wrap the vanilla network in an ActMAD engine that watches a single layer
# and takes one optimisation step at lr 2e-4.
engine = ActMADEngine(
    network,
    ["block1.module_8.aggr_module"],
    optimization_parameters=dict(lr=2e-4, num_steps=1),
)

# Reference statistics are computed on the training split.
train_dataloader = DataLoader(dataset, batch_size=16)
engine.compute_statistics(train_dataloader)

In [None]:
# Per-sample OOD errors of the ActMAD-adapted vanilla network, one entry per batch size.
actmad_results = {}

for bs in [2, 4, 16]:
    # Dedicated name: assigning to `test_loader` here would overwrite the
    # in-distribution loader that the vanilla evaluation cell relies on.
    actmad_loader = DataLoader(ood_dataset, bs)
    engine.eval()
    t = tqdm.trange(len(actmad_loader), desc='Current Loss = ', leave=True)
    errors = []
    targets = []
    for _, batch in zip(t, actmad_loader):
        batch = batch.cuda()
        # The engine output is indexed with [0] to get the predictions.
        y_hat = engine(batch)
        mae = torch.abs(y_hat[0][:, 0].detach().cpu() - batch.y.cpu()).numpy().ravel().tolist()
        targets += batch.y.ravel().tolist()
        errors += mae
        t.set_description(f"Current Loss = {sum(errors) / len(errors)}", refresh=True)

    targets, errors = np.array(targets), np.array(errors)

    # Five bins by target quintile; both edges are inclusive.
    errs = []
    for i in range(5):
        q1, q2 = np.quantile(targets, q=0.2 * i), np.quantile(targets, q=0.2 * (i + 1))
        mask = np.logical_and(targets >= q1, targets <= q2)
        err = errors[mask].tolist()
        errs.append(err)

    # Flatten from the highest-target bin downwards and drop the lowest bin.
    actmad_results[bs] = [s for t in errs[::-1][:-1] for s in t]

In [None]:
import numpy as np

# Short names for the per-batch-size result lists used by the plotting cell.
MAE_2_batch, MAE_4_batch, MAE_16_batch = (ittt_results[size] for size in (2, 4, 16))

ActMAD_16_batch, ActMAD_4_batch, ActMAD_2_batch = (actmad_results[size] for size in (16, 4, 2))

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
# Box plot of per-sample MAE by method and out-of-distribution level.

fontsize = 20

# Samples per OOD level inside every flattened result list, in list order.
# NOTE(review): these counts are tied to the current OOD split and binning;
# update them if either changes.
LEVEL_SIZES = {1: 58, 2: 57, 3: 58, 4: 57}
levels = [level for level, size in LEVEL_SIZES.items() for _ in range(size)]

# Insertion order defines the hue order in the plot and legend.
results_by_method = {
    'Not Optimized': NOT_OPTIMIZED,
    'ActMAD (batch=2)': ActMAD_2_batch,
    'ActMAD (batch=4)': ActMAD_4_batch,
    'ActMAD (batch=16)': ActMAD_16_batch,
    '$IT^3$ (batch=2)': MAE_2_batch,
    '$IT^3$ (batch=4)': MAE_4_batch,
    '$IT^3$ (batch=16)': MAE_16_batch,
}

# Fail early with a clear message instead of a generic pandas length error.
for method, values in results_by_method.items():
    if len(values) != len(levels):
        raise ValueError(
            f"{method}: expected {len(levels)} values "
            f"(LEVEL_SIZES={LEVEL_SIZES}), got {len(values)}"
        )

# Long format: one row per (method, level, error). The 'Accuracy' column
# holds absolute errors; the axis is labelled 'MAE' below.
data = {
    'Method': [method for method, values in results_by_method.items() for _ in values],
    'Out-of-distribution Level': levels * len(results_by_method),
    'Accuracy': [value for values in results_by_method.values() for value in values],
}
df = pd.DataFrame(data)

# One colour family per method: teal baseline, blues for ActMAD, purples for IT^3.
custom_palette = {
    'Not Optimized': '#96CAC1',
    'ActMAD (batch=2)': '#022db8e3',
    'ActMAD (batch=4)': '#2647b5e3',
    'ActMAD (batch=16)': '#475fade3',
    '$IT^3$ (batch=2)': '#5304a8e3',
    '$IT^3$ (batch=4)': '#69349ee3',
    '$IT^3$ (batch=16)': '#7e5ca1e3',
}

fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(x='Out-of-distribution Level', y='Accuracy', hue='Method', data=df,
            palette=custom_palette, showfliers=False, ax=ax)

ax.set_xlabel('Out-of-distribution Level', fontsize=fontsize)
ax.set_ylabel('MAE', fontsize=fontsize)
ax.tick_params(axis='both', labelsize=fontsize)
ax.legend(fontsize=fontsize, ncol=3, framealpha=0.4)

ax.grid(True, axis='y', linestyle='--')
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)

fig.tight_layout()

# Show plot
# plt.show()
# plt.savefig("./viz/airfoils_ittt_results_box_updated.pdf", format="pdf")