# Example of using augmentation to do contrastive learning

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## Import packages and load data

In [1]:
import sys
sys.path.insert(1, '../src/')


from data import PhylogenyTree, PhylogenyDataset
from ete4 import Tree
from mixup import Mixup
dataset = "alzbiom"
target = "ad"

data_fp = f'../data/{dataset}/data.tsv.xz'
meta_fp = f'../data/{dataset}/meta.tsv'
target_fp = f'../data/{dataset}/{target}.py'

phylogeny_tree_fp = '../data/WoL2/phylogeny.nwk'
tree = PhylogenyTree.init_from_nwk(phylogeny_tree_fp)
data = PhylogenyDataset.init_from_files(data_fp, meta_fp, target_fp)
tree.prune(data.features)
print("number of leaves in the phylogeny tree after pruning: ", len(list(tree.ete_tree.leaves())))

number of leaves in the phylogeny tree after pruning:  8350


## Get augmented data

Contrastive learning requires a pair of augmented data, and we can using phylomix to generate them by specifying different random seed

In [3]:
from mixup import Mixup
import numpy as np

data.one_hot_encode()
mixup = Mixup(data, taxonomy_tree=None, phylogeny_tree=tree)
idx1 = np.arange(len(data.X))
idx2 = idx1[::-1]
out1 = mixup.mixup(len(idx1), method='phylomix', alpha=2, tree='phylogeny', index1=idx1, index2=idx2, contrastive_learning=True, seed=0)
out2 = mixup.mixup(len(idx1), method='phylomix', alpha=2, tree='phylogeny', index1=idx1, index2=idx2, contrastive_learning=True, seed=1)

## Define a simple model as encoder

In [4]:
class EncoderProjectionHead(nn.Module):
    def __init__(self, input_dim, hidden_dim1=1024, hidden_dim2=256, latent_dim=512, output_dim=512):
        super(EncoderProjectionHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, hidden_dim1)
        self.fc3 = nn.Linear(hidden_dim1, output_dim)
        self.encoder = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            self.fc2,
            nn.ReLU(),
            self.fc3
        )
        self.projection_head = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128)
        )

    def forward(self, x):
        z = self.encoder(x)
        z = self.projection_head(z)
        return F.normalize(z, dim=1)  # Normalize the output

## Define the simclr loss function

In [5]:
def info_nce_loss(features, batch_size, device):

    labels = torch.cat([torch.arange(batch_size) for i in range(2)], dim=0).to(device)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels

    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long)

    logits = logits / 0.05
    return logits.to(device), labels.to(device)

## Make data into data loader

In [9]:
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
combined_data = np.stack((out1, out2), axis=1)
combined_data = torch.tensor(combined_data, dtype=torch.float32).to(device)
dataloader = DataLoader(combined_data, batch_size=combined_data.shape[0], shuffle=True)

## Train the model

In [19]:
from tqdm import tqdm

input_dim = data.X.shape[-1]
encoder = EncoderProjectionHead(input_dim).to(device)
epochs = 1000
optimizer = optim.Adam(encoder.parameters(), lr=0.001)
batch_size = dataloader.batch_size
criterion = nn.CrossEntropyLoss()


for epoch in range(epochs):
    epoch_loss = 0

    with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
        for d in pbar:
            d = d.to(device).view(-1, d.size()[-1])
            optimizer.zero_grad()
            z = encoder(d)
            logits, labels = info_nce_loss(z, batch_size, device)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_description(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item():.6f}")

    avg_loss = epoch_loss / len(dataloader)



Epoch 1/1000 - Loss: 14.655119: 100%|██████████| 1/1 [00:00<00:00, 19.44it/s]
Epoch 2/1000 - Loss: 18.181810: 100%|██████████| 1/1 [00:00<00:00, 14.36it/s]
Epoch 3/1000 - Loss: 18.576300: 100%|██████████| 1/1 [00:00<00:00, 12.18it/s]
Epoch 4/1000 - Loss: 17.687069: 100%|██████████| 1/1 [00:00<00:00, 14.31it/s]
Epoch 5/1000 - Loss: 17.396997: 100%|██████████| 1/1 [00:00<00:00, 17.22it/s]
Epoch 6/1000 - Loss: 17.497082: 100%|██████████| 1/1 [00:00<00:00, 38.68it/s]
Epoch 7/1000 - Loss: 16.841011: 100%|██████████| 1/1 [00:00<00:00, 26.53it/s]
Epoch 8/1000 - Loss: 17.968037: 100%|██████████| 1/1 [00:00<00:00, 17.50it/s]
Epoch 9/1000 - Loss: 15.996158: 100%|██████████| 1/1 [00:00<00:00, 16.39it/s]
Epoch 10/1000 - Loss: 17.054409: 100%|██████████| 1/1 [00:00<00:00, 11.97it/s]
Epoch 11/1000 - Loss: 16.897173: 100%|██████████| 1/1 [00:00<00:00, 36.44it/s]
Epoch 12/1000 - Loss: 16.158663: 100%|██████████| 1/1 [00:00<00:00, 20.88it/s]
Epoch 13/1000 - Loss: 16.355505: 100%|██████████| 1/1 [00:00<

## Get the embedding

In [27]:
encoder.eval()
with torch.no_grad():
    feature_extractor = encoder.encoder
    embedding = feature_extractor(torch.tensor(data.X).to(device)).cpu().numpy()
embedding.shape

(175, 512)

## We can use the embedding for downstream task

In [31]:
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import accuracy_score, roc_auc_score
import warnings

warnings.filterwarnings("ignore")
y = np.argmax(data.y, axis=1)
logistic_cv = LogisticRegressionCV(cv=5, max_iter=1000, random_state=42)

y_pred_cv = cross_val_predict(logistic_cv, embedding, y, cv=50, method="predict")

auroc = roc_auc_score(y, y_pred_cv)

accuracy = accuracy_score(y, y_pred_cv)

print(f"Cross-validated Accuracy: {accuracy:.4f}")
print(f"Cross-validated AUROC: {auroc:.4f}")

Cross-validated Accuracy: 0.5543
Cross-validated AUROC: 0.5400
