<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/JEPA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install optuna


In [None]:

# https://openreview.net/pdf?id=BZ5a1r-kVsf
# import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import collections
device = "cuda" if torch.cuda.is_available() else "cpu"
import optuna

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    The entries are returned in row-major order with the main diagonal
    skipped, so an ``n x n`` input yields ``n * (n - 1)`` values.

    Args:
        x: 2-D square tensor.

    Returns:
        1-D tensor holding every element of ``x`` that is not on the main
        diagonal.

    Raises:
        ValueError: if ``x`` is not a 2-D square tensor.
    """
    if x.dim() != 2 or x.shape[0] != x.shape[1]:
        raise ValueError(
            f"off_diagonal expects a square 2-D tensor, got shape {tuple(x.shape)}"
        )
    n = x.shape[0]
    # Dropping the last element and viewing the remaining n*n - 1 values as
    # (n - 1, n + 1) places every diagonal entry in column 0; slicing that
    # column away leaves exactly the off-diagonal entries.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class JEPA(nn.Module):
    """Toy joint-embedding predictive model trained with a VICReg-style loss.

    Two convolutional encoders embed ``x`` and ``y``; a predictor maps the
    embedding of ``x`` concatenated with a latent ``z`` to a prediction of
    the embedding of ``y``; two expander heads project both embeddings to
    ``dim_v`` features for the VICReg regulariser.

    Each encoder ends in a single-channel ``AdaptiveAvgPool2d((5, 4))``, so
    it emits 20 values per sample. ``dim_sx`` and ``dim_sy`` therefore have
    to be 20 for ``forward`` to work.

    Args:
        xin_channels: number of channels of the input images.
        dim_sx: size of the flattened embedding of ``x``.
        dim_sy: size of the flattened embedding of ``y``.
        dim_z: number of latent values fed to the predictor.
        dim_v: output size of the expander heads.
    """

    def __init__(self, xin_channels, dim_sx, dim_sy, dim_z, dim_v):
        super().__init__()
        # Remembered so that argm() searches over the number of latent
        # values the predictor was sized for.
        self.dim_z = dim_z
        self.enc_x = self._make_encoder(xin_channels)
        self.enc_y = self._make_encoder(xin_channels)
        self.pred = nn.Sequential(
            nn.Linear(dim_sx + dim_z, dim_sy),
            nn.ReLU(True),
            )
        self.exp_x = nn.Sequential(
            nn.Linear(dim_sx, dim_v),
            nn.ReLU(True),
            )
        self.exp_y = nn.Sequential(
            nn.Linear(dim_sy, dim_v),
            nn.ReLU(True),
            )

    @staticmethod
    def _make_encoder(in_channels):
        """Build the conv stack shared by ``enc_x`` and ``enc_y`` (20 outputs per sample)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 8, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(8, 16, 5, stride=2, padding=2), nn.ELU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 8, 7, stride=2, padding=3), nn.ELU(),
            nn.Conv2d(8, 1, 5, stride=2, padding=2), nn.ELU(),
            nn.AdaptiveAvgPool2d((5, 4)),
            )

    @staticmethod
    def _covariance_penalty(v):
        """Return the mean squared off-diagonal covariance of centred features ``v``."""
        if v.dim() == 1:
            # A lone embedding is treated as a batch holding one sample.
            v = v.view(1, -1)
        n_samples, n_features = v.shape
        # max(..., 1) keeps the single-sample case finite instead of
        # dividing by zero.
        cov = (v.T @ v) / max(n_samples - 1, 1)
        return off_diagonal(cov).pow(2).sum().div(n_features)

    def vicreg(self, x, y, sim_coeff=1.0, std_coeff=1.0, cov_coeff=1.0):
        """Compute the VICReg loss (invariance + variance + covariance).

        Follows https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        but takes the batch size and feature count from the tensors
        themselves.

        Args:
            x: embeddings of shape ``(batch, features)`` or a single 1-D
                embedding of shape ``(features,)``.
            y: embeddings with the same layout as ``x``.
            sim_coeff: weight of the invariance (MSE) term.
            std_coeff: weight of the variance term.
            cov_coeff: weight of the covariance term.

        Returns:
            Scalar tensor with the weighted sum of the three terms.

        Note:
            For 1-D inputs the mean and variance are taken over dim 0, which
            is then the feature axis; the statistics only describe a batch
            when 2-D inputs are given.
        """
        # invariance loss
        repr_loss = F.mse_loss(x, y)

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss: hinge on the standard deviation along dim 0
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # covariance loss: penalise correlation between different features
        cov_loss = self._covariance_penalty(x) + self._covariance_penalty(y)

        return sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss

    def argm(self, sx, sy, n_trials=100):
        """Search for the latent ``z`` in ``[-1, 1]`` that minimises the prediction error.

        Runs an Optuna study whose objective is the MSE between ``sy`` and
        ``self.pred(cat([sx, z]))``.

        Args:
            sx: 1-D embedding of ``x``.
            sy: 1-D embedding of ``y``.
            n_trials: number of Optuna trials to run.

        Returns:
            1-D tensor of length ``dim_z`` holding the best latent found, on
            the device and with the dtype of ``sx``. It carries no gradient.
        """
        optuna.logging.set_verbosity(optuna.logging.WARNING)
        sampler = optuna.samplers.NSGAIISampler()
        # The objective is an error, so the study has to minimise it.
        study = optuna.create_study(direction="minimize", sampler=sampler, pruner=optuna.pruners.MedianPruner())
        names = [f"z{i}" for i in range(self.dim_z)]
        # The search only evaluates the predictor; no gradients are needed.
        sx_fixed = sx.detach()
        sy_fixed = sy.detach()

        def objective(trial):
            values = [trial.suggest_float(name, -1.0, 1.0) for name in names]
            z = torch.tensor(values, dtype=sx_fixed.dtype, device=sx_fixed.device)
            with torch.no_grad():
                sy_ = self.pred(torch.cat([sx_fixed, z], dim=-1))
                # Optuna expects a plain float objective value.
                return F.mse_loss(sy_, sy_fixed).item()

        study.optimize(objective, n_trials=n_trials)
        best = study.best_params
        return torch.tensor([best[name] for name in names], dtype=sx.dtype, device=sx.device)

    def forward(self, x, y):
        """Return the training loss for one ``(x, y)`` pair.

        Both encoder outputs are flattened completely before being fed to
        the linear layers, so ``x`` and ``y`` must each hold a single sample.

        Args:
            x: input image tensor for ``enc_x``.
            y: input image tensor for ``enc_y``.

        Returns:
            Scalar tensor: prediction MSE plus the VICReg loss of the
            expanded embeddings.
        """
        sx = self.enc_x(x).flatten()
        sy = self.enc_y(y).flatten()
        z = self.argm(sx, sy)
        sy_ = self.pred(torch.cat([sx, z], dim=-1))
        mseloss = F.mse_loss(sy, sy_)

        # TODO: a regulariser on the latent z is not implemented yet.

        vx = self.exp_x(sx)
        vy = self.exp_y(sy)
        vicloss = self.vicreg(vx, vy)
        return mseloss + vicloss


# Hyper-parameters handed to the JEPA constructor below.
xin_channels, dim_sx, dim_sy, dim_z, dim_v = 3, 20, 20, 1, 40
model = JEPA(
    xin_channels=xin_channels,
    dim_sx=dim_sx,
    dim_sy=dim_sy,
    dim_z=dim_z,
    dim_v=dim_v,
)


# Smoke test: push one pair of random (3, 210, 160) tensors through the
# model and print the resulting loss.
x, y = torch.rand(3, 210, 160), torch.rand(3, 210, 160)
inv_loss = model(x, y)
print(inv_loss)


# enc_x = nn.Sequential( # embed pi (240, 256, 3) -> 256 when flattened
#             nn.Conv2d(xin_channels, 8, 3, stride=2, padding=1), nn.ELU(),
#             # nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(8, 16, 5, stride=2, padding=2), nn.ELU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(16, 8, 7, stride=2, padding=3), nn.ELU(),
#             nn.Conv2d(8, 1, 5, stride=2, padding=2), nn.ELU(),
#             nn.AdaptiveAvgPool2d((5,4)),
#             # # nn.Conv2d(in_channels, out_channels=1, kernel_size=3, stride=2, padding=1),
#             # nn.ReLU(),
#             )
# x=torch.rand(3, 210, 160)
# sx=enc_x(x)
# print("sx.shape",sx.shape) # [1, 256] [1, 5, 4]
# sx=sx.flatten()
# print("sx.shape",sx.shape) # [20]


# pred = nn.Sequential(
#             nn.Linear(dim_sx + dim_z, dim_sy),
#             nn.ReLU(True),
#             )

# # sx =torch.rand(1, 16, 16)
# z =torch.rand(1)
# z=torch.tensor(z)
# sxz = torch.cat([sx, z], dim=-1)
# print(sxz.shape) #257
# sy_ = pred(sxz)



x.dim() 1
x torch.Size([1, 40])
cov_x torch.Size([40, 40])
off_diagonal torch.Size([40, 40])
off_diagonal torch.Size([40, 40])
tensor(0.9638, grad_fn=<AddBackward0>)
