# X-CLIP loss reimplementation

Inspired by https://github.com/lucidrains/x-clip/blob/main/x_clip/x_clip.py

In [None]:
# Enable IPython's autoreload so edits to imported modules (e.g. the local
# `src` package) are picked up before each cell runs, without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
import os
import os.path as osp
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from lightning.pytorch import LightningModule
from lion_pytorch import Lion
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import (
    Compose,
    Lambda,
    Normalize,
    RandomCrop,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToTensor,
)
from torchvision.transforms.functional import InterpolationMode
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    ViTImageProcessor,
    ViTMAEConfig,
    ViTMAEForPreTraining,
)

from src.mae.module import MAEDatasetConfig, MAEModule, MAEOptimizerConfig
from src.modules.transforms import ComplexTransform, SimpleTransform
from src.modules.transforms.color_jitter import ColorJitterPerChannel

In [None]:
import subprocess

# Mount the cpjump1..cpjump3 project directories from `bioclust` via sshfs,
# skipping any whose `jump/` subdirectory is already visible locally.
# Best-effort: a failed mount is reported, not raised, so the notebook keeps going.
for i in range(1, 4):
    mount_point = Path(f"../cpjump{i}")
    if (mount_point / "jump").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    try:
        # Argument list (no shell) instead of os.system: no string interpolation
        # into a shell, and the exit status is actually inspected below.
        result = subprocess.run(
            ["sshfs", f"bioclust:/projects/cpjump{i}/", str(mount_point)],
            check=False,
        )
    except FileNotFoundError:
        print(f"Failed to mount cpjump{i}: `sshfs` executable not found.")
        continue

    if result.returncode != 0:
        print(f"Failed to mount cpjump{i} (sshfs exit code {result.returncode}).")

Mounting cpjump1...
Mounting cpjump2...
Mounting cpjump3...


## Contrastive loss

In [3]:
def matrix_diag(t):
    """Extract the main diagonal of every matrix stacked in ``t``.

    ``t`` is read as ``(..., i, j)``. All leading dimensions are flattened
    into a single axis, so the result has shape
    ``(prod(leading dims), min(i, j))``. A plain ``(i, j)`` matrix therefore
    yields one row of shape ``(1, min(i, j))``. Non-square matrices are
    supported: only the entries ``t[..., k, k]`` for ``k < min(i, j)`` are taken.
    The result is a new tensor, not a view of ``t``.
    """
    rows, cols = t.shape[-2:]
    # Boolean identity pattern. Broadcasting it over the leading dims selects
    # t[..., k, k] from every matrix, in row-major order of the batch.
    eye_mask = torch.eye(rows, cols, dtype=torch.bool, device=t.device)
    diagonal_values = t.masked_select(eye_mask)
    return diagonal_values.view(-1, min(rows, cols))

In [10]:
# Twelve random 2x2 matrices laid out on a (4, 3, 1) batch grid.
t = torch.randn((4, 3, 1, 2, 2))
t[0]  # display the first (3, 1, 2, 2) slab

tensor([[[[ 1.3603,  1.5134],
          [ 0.2363,  0.9996]]],


        [[[ 0.3233,  0.9008],
          [ 1.5954, -0.4311]]],


        [[[-0.0736, -0.4979],
          [ 0.0264, -0.5729]]]])

In [16]:
# The (4, 3, 1) leading dims are flattened into 12 rows; each row holds the
# 2 diagonal entries of one 2x2 matrix (first 3 rows correspond to t[0]).
matrix_diag(t)

tensor([[ 1.3603,  0.9996],
        [ 0.3233, -0.4311],
        [-0.0736, -0.5729],
        [ 0.1492,  1.0441],
        [ 1.1774,  1.4675],
        [ 0.2147, -0.7641],
        [ 1.3447,  0.3734],
        [ 1.5712,  1.1409],
        [-1.6532,  0.5066],
        [ 2.3777, -0.3648],
        [ 0.5140, -0.1949],
        [-0.5391,  0.4937]])