In [None]:
import torch
import torch.nn.functional as F
from matplotlib import cm
import imageio.v3 as iio
import torch.nn as nn
import numpy as np

import monai
from model_nerv import Generator

from monai.networks.nets.segresnet import SegResNet
# Load the stored argument object for the NeRV generator and override two of its fields.
# This module-level `args` is read later by k_anonymous() when it rebuilds the Generator.
# NOTE(review): torch.load unpickles the file - only load 'args_nerv.pth' from a trusted source.
args = torch.load('args_nerv.pth')
args.embed_length = 2*2*8*2
args.fc_hw_dim = '3_3_16'

import os

In [None]:


import torch
from torchvision.transforms import Resize,CenterCrop
from tqdm.notebook import trange,tqdm
from scipy import interpolate
import imageio.v3 as iio

In [None]:
# Download the data from https://cloud.imi.uni-luebeck.de/s/tsMSd8wSAQTKYxx before running the cells below.

In [None]:
# Load the paired image tensor used by every later cell.
# NOTE(review): later cells index it as img_paired[case, view, ...] with the image in the
# last two dims - confirm the stored layout matches.
img_paired = torch.load('img_paired_nlst.pth')

In [None]:
import torchvision
from torchvision.models import resnet34
from torchvision.models.resnet import ResNet34_Weights
from torchvision.transforms import v2
# Augmentation pipeline for training (also reused for test-time augmentation):
# each of the three transforms is applied with probability 0.5.
transforms_train = v2.Compose(
    [
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomPhotometricDistort(p=0.5),
        v2.RandomErasing(p=0.5),
    ]
)
# Validation applies no augmentation at all.
transforms_val = nn.Identity()

# Move the image tensor to the GPU first, then convert it to float32.
img_paired = img_paired.cuda()
img_paired = img_paired.float()

def infoNCE(output, temperature=0.07):
    """InfoNCE contrastive loss over a batch made of two stacked views.

    Row ``i`` and row ``i + N/2`` of ``output`` are treated as a positive pair;
    every other row in the batch acts as a negative.

    Args:
        output: ``(N, D)`` embedding tensor with ``N`` even (first half = view A,
            second half = view B, in the same case order).
        temperature: divisor applied to the cosine similarities before the
            softmax (default 0.07, the value previously hard-coded).

    Returns:
        ``(loss, sim)`` where ``loss`` is the mean negative log-likelihood of
        the positive partner and ``sim`` is the ``(N/2, N/2)`` block of raw
        cosine similarities between view-B rows and view-A rows.

    Raises:
        ValueError: if ``N`` is odd, because rows could not be paired.
    """
    n_total = output.shape[0]
    if n_total % 2 != 0:
        raise ValueError(f"infoNCE expects an even number of rows, got {n_total}")
    B = n_total // 2

    # Pairwise cosine similarity computed with the functional API so that the
    # function does not depend on a module-level `cos` object being defined
    # (and run) in some other notebook cell beforehand.
    cos_sim = F.cosine_similarity(output.unsqueeze(0), output.unsqueeze(1), dim=2, eps=1e-6)

    # Raw (unmasked, unscaled) similarities between the two views.
    sim = cos_sim[B:, :B]

    # Exclude each sample's similarity with itself from the softmax.
    self_mask = torch.eye(n_total, dtype=torch.bool, device=cos_sim.device)
    cos_sim = cos_sim.masked_fill(self_mask, -9e15)

    # Positive example -> batch_size//2 away from the original example.
    pos_mask = self_mask.roll(shifts=n_total // 2, dims=0)

    # InfoNCE loss.
    cos_sim = cos_sim / temperature
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    loss = nll.mean()
    return loss, sim




In [None]:
#TRAIN re-ident directly on images
# Backbone: torchvision ResNet-34 with its default pretrained weights; the final
# fully-connected layer is replaced so the network outputs a 256-dim embedding.
model_ = resnet34(weights=ResNet34_Weights.DEFAULT)
model_.fc = nn.Linear(512,256,bias=True)
model_.cuda()
# `model` is the compiled wrapper used for training; `model_` (the plain module)
# is what the evaluation cell passes to run_reident_pair.
model = torch.compile(model_)
model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
# StepLR(step_size=4000, gamma=0.2): learning rate is multiplied by 0.2 after 4000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,4000,0.2)
# Gradient scaler for the mixed-precision (autocast) forward pass below.
scaler = torch.cuda.amp.GradScaler()

iterations = 8000
# Per-iteration logs: in-batch retrieval accuracy and loss.
run_sim = torch.zeros(iterations)
run_loss = torch.zeros(iterations)
# Module-level cosine-similarity object (dim=2).
cos = nn.CosineSimilarity(dim=2, eps=1e-6)
# Number of cases per iteration; each contributes two views, so the network sees 2*B images.
B = 32
with trange(iterations) as pbar:
    for i in pbar:
        # Sample B distinct cases from the first 114; cases from index 114 onward
        # are used as the hold-out set in the evaluation cell.
        idx = torch.randperm(114)[:B]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # Randomly decide which of the two views goes into the first half of the batch.
            ij = torch.randperm(2)[:2]; ij1 = int(ij[0]); ij2 = int(ij[1]);
            # Identity affine plus Gaussian noise (std 0.075) per image; sampled onto a 360x360 grid.
            grid = F.affine_grid(torch.eye(2,3).unsqueeze(0).cuda()+torch.randn(B*2,2,3).cuda()*.075,(B*2,1,360,360))
            # Augmentation is done without tracking gradients.
            with torch.no_grad():
                inputs = transforms_train(F.grid_sample(torch.cat((img_paired[idx,ij1:ij1+1],img_paired[idx,ij2:ij2+1]),0),grid))
            # Single-channel images are repeated to 3 channels for the ResNet; embeddings are squashed by a sigmoid.
            output = torch.sigmoid(model(inputs.repeat(1,3,1,1)))
            loss,sim = infoNCE(output)
        
        run_loss[i] = loss.item()
        # Fraction of columns of `sim` whose largest entry lies on the diagonal (correct partner retrieved).
        run_sim[i] = (sim.argmax(0)==torch.arange(B).cuda()).float().mean()#sim.trace()/4
        scaler.scale(loss).backward()
        #scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        if(i>10):
            # Progress bar shows the mean over iterations i-10 .. i-2 (the slice excludes the two most recent).
            pbar.set_postfix(loss=run_loss[i-10:i-1].mean().item(),sim=run_sim[i-10:i-1].mean().item(),vram='%0.2f'%(torch.cuda.max_memory_allocated()*1e-9))
        

            
#torch.save(model_.state_dict(),'resnet2_nlst.pth')

In [None]:
#function to perform re-ident with test time augmentation (TTA)
def run_reident_pair(model,img_paired_test,quad=False):
    """Re-identification evaluation with test-time augmentation (TTA).

    For 25 TTA rounds, every view of every test case is randomly augmented and
    embedded, the pairwise cosine-similarity matrix of all embeddings is added
    to a running sum, and the rank at which each case's partner view is
    retrieved from the accumulated similarities is recorded.

    Args:
        model: callable mapping an ``(N, 3, H, W)`` batch to ``(N, 256)``
            embeddings. Its train/eval mode is not changed here.
            NOTE(review): callers in this notebook pass a module that was last
            put in ``train()`` mode - confirm that is intended for evaluation.
        img_paired_test: tensor indexed as ``[case, view, H, W]`` holding two
            views per case (four when ``quad`` is True).
        quad: if True, use four views per case and score the best-ranked
            partner among views 1..3.

    Returns:
        ``top15``: ``(15, 25)`` tensor; ``top15[k, r]`` is the fraction of cases
        whose partner rank is ``<= k`` after accumulating ``r + 1`` TTA rounds.
        The top-1 / top-5 accuracy after the final round is also printed.
    """
    N_t = img_paired_test.shape[0]
    H,W = img_paired_test.shape[-2:]
    B = 8                       # cases per forward pass
    n_views = 4 if quad else 2
    n_rounds = 25

    cos = nn.CosineSimilarity(dim=2, eps=1e-6)
    cos_sim = torch.zeros(N_t*n_views,N_t*n_views).cuda()
    # feats[v, j] holds the embedding of view v of case j for the current round.
    feats = torch.zeros(n_views,N_t,256).cuda()
    top15 = torch.zeros(15,n_rounds)

    identity = torch.eye(2,3).unsqueeze(0).cuda()
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)

    for r in trange(n_rounds):
        # Step through the cases in chunks of B; the last chunk may be smaller,
        # so no test case is silently skipped when N_t is not a multiple of B.
        for start in range(0,N_t,B):
            stop = min(start+B,N_t)
            b = stop-start
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    # Identity affine plus Gaussian noise (std 0.05), sampled at the
                    # native resolution of the tensor that was passed in.
                    grid = F.affine_grid(identity+torch.randn(b*n_views,2,3).cuda()*.05,(b*n_views,1,H,W))
                    # Always read from the argument (the quad branch previously read
                    # the global `img_paired`, ignoring `img_paired_test`).
                    views = torch.cat([img_paired_test[start:stop,v:v+1] for v in range(n_views)],0)
                    inputs = transforms_train(F.grid_sample(views,grid))
                    output = torch.sigmoid(model(inputs.repeat(1,3,1,1)))
                    # Rows of `output` are ordered view-major: view 0 cases, view 1 cases, ...
                    feats[:,start:stop] = output.reshape(n_views,b,-1)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                # Flattening view-major matches concatenating the per-view feature blocks.
                all_feat = feats.reshape(n_views*N_t,-1)
                cos_sim1 = cos(all_feat.unsqueeze(0),all_feat.unsqueeze(1))
        # Accumulate similarities over TTA rounds.
        cos_sim += cos_sim1
        cos_sim_ = cos_sim.clone()
        cos_sim_.masked_fill_(self_mask, 0)

        # For every view-0 embedding, sort all embeddings by accumulated similarity
        # and look up the position of its partner(s).
        order = torch.argsort(cos_sim_[:N_t],descending=True,dim=1)
        rank = None
        for v in range(1,n_views):
            partner = torch.arange(v*N_t,(v+1)*N_t).cuda().view(-1,1)
            rank_v = torch.nonzero(order==partner)[:,1].data.float().cpu()
            rank = rank_v if rank is None else torch.minimum(rank,rank_v)
        for rr in range(15):
            top15[rr,r] = (rank<=rr).float().mean()

    print(' top1 acc',(rank==0).float().mean(),'top 5 acc',(rank<5).float().mean())
    return top15


In [None]:
# Evaluate on the hold-out test set: all cases from index 114 onward
# (the training loop only samples case indices below 114).
img_paired_test = img_paired[114:].clone()
# The plain module `model_` is passed here rather than the compiled wrapper `model`.
top15 = run_reident_pair(model_,img_paired_test,quad=False)
#returns #IMAGE:  top1 acc tensor(0.5312) top 5 acc tensor(0.7292)


In [None]:
#torch.save(model_.state_dict(),'nlst_reident_img.pth')

In [None]:
#setup NeRV model
from tqdm import tqdm,trange
import time
from skimage.metrics import peak_signal_noise_ratio as psnr
# Process-wide PyTorch setting for float32 matrix multiplications ('high' precision mode);
# it affects every later cell, not only the NeRV code below.
torch.set_float32_matmul_precision('high')
from math import exp

from model_nerv import Generator

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to one.

    The bell curve is centred on index ``window_size // 2`` and has standard
    deviation ``sigma``.
    """
    centre = window_size//2
    two_sigma_sq = float(2*sigma**2)
    weights = []
    for pos in range(window_size):
        weights.append(exp(-(pos - centre)**2/two_sigma_sq))
    kernel = torch.Tensor(weights)
    return kernel/kernel.sum()

def create_window(window_size, channel):
    """Build the 2-D Gaussian filter bank used by the SSIM functions.

    Args:
        window_size: side length of the square window.
        channel: number of image channels; the same window is replicated once
            per channel so it can be used with a grouped convolution.

    Returns:
        Contiguous float tensor of shape ``(channel, 1, window_size, window_size)``.
    """
    # Outer product of the 1-D Gaussian (sigma = 1.5) with itself gives the 2-D window.
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # Return a plain tensor: the previous torch.autograd.Variable wrapper was never
    # imported in this notebook (NameError) and is a deprecated no-op on tensors.
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
#and SSIM similarity 
class SSIM(torch.nn.Module):
    """Module form of the SSIM measure that keeps its Gaussian window cached.

    The window is built once for a single channel and rebuilt only when an
    input arrives with a different channel count or tensor type.
    """

    def __init__(self, window_size = 11, size_average = True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2`` (both 4-D batches)."""
        _, channel, _, _ = img1.size()

        cache_is_valid = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_is_valid:
            # Rebuild the window so it matches the input's channels, device and type,
            # then remember it for the following calls.
            rebuilt = create_window(self.window_size, channel)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            self.window = rebuilt.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    """Functional SSIM: builds a fresh Gaussian window matching ``img1`` on every call.

    ``img1`` must be a 4-D batch; its channel count, device and tensor type
    determine the window that is passed on to ``_ssim``.
    """
    _, channel, _, _ = img1.size()

    window = create_window(window_size, channel)
    # Put the window on the same GPU as the input, if the input lives on one.
    window = window.cuda(img1.get_device()) if img1.is_cuda else window
    matched_window = window.type_as(img1)

    return _ssim(img1, img2, matched_window, window_size, channel, size_average)

def fit_nerv(imgs_in,N_case,num_iterations=2500,batch_size=16):
    """Fit one NeRV generator (plus a per-case embedding table) per fold of images.

    The first ``imgs_in.shape[0] // N_case * N_case`` cases are split into
    consecutive folds of ``N_case`` cases; any trailing remainder is not
    fitted. Only channel 0 of ``imgs_in`` is used as the reconstruction target.

    Args:
        imgs_in: image tensor indexed as ``[case, channel, H, W]``.
        N_case: number of cases represented by each generator.
        num_iterations: optimisation steps per fold (default 2500).
        batch_size: cases drawn per step (default 16).

    Returns:
        List with one entry per fold:
        ``[generator_state_dict (after .half()), embedding_state_dict, run_psnr]``.
    """
    # Function-scope import: `sys` is needed for the progress bar below but is
    # not imported by any of the notebook's import cells.
    import sys

    # NOTE(review): torch.load unpickles the file - only use a trusted 'args_nerv.pth'.
    args = torch.load('args_nerv.pth')
    args.embed_length = 2*2*8*2
    args.fc_hw_dim = '3_3_16'

    folds = imgs_in.shape[0]//N_case
    models = []
    for subfold in range(folds):

        model = Generator(embed_length=args.embed_length, stem_dim_num=args.stem_dim_num, fc_hw_dim=args.fc_hw_dim, expansion=args.expansion, 
            num_blocks=args.num_blocks, norm=args.norm, act=args.act, bias = True, reduction=args.reduction, conv_type=args.conv_type,
            stride_list=args.strides,  sin_res=args.single_res,  lower_width=args.lower_width, sigmoid=args.sigmoid)
        model.cuda()

        # One learnable 64-dim code per case of this fold, optimised jointly with the generator.
        embed_wb = nn.Embedding(N_case,64).cuda()
        optimizer = torch.optim.Adam(list(model.parameters())+list(embed_wb.parameters()),lr=0.001)

        run_loss = torch.zeros(num_iterations)
        run_psnr = torch.zeros(num_iterations)
        with tqdm(total=num_iterations, file=sys.stdout) as pbar:
            for i in range(num_iterations):
                optimizer.zero_grad()
                # Indices are local to the fold; offset by the fold start to address imgs_in.
                idx = torch.randperm(N_case)[:batch_size].cuda()
                target = imgs_in[idx.cpu()+subfold*N_case,0].cuda()
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    code = embed_wb(idx)
                    x = model(code)[-1][:,0]
                # The loss is evaluated in float32 outside the autocast region.
                loss = 1-ssim(target.unsqueeze(1).float(),x.unsqueeze(1).float())
                loss.backward()
                optimizer.step()

                run_psnr[i] = psnr(torch.clamp(target.float().cpu().data.view(-1),0,1).numpy(),torch.clamp(x.float().cpu().data.reshape(-1),0,1).numpy(),)
                run_loss[i] = loss.item()

                # Running mean over up to the last 28 recorded iterations, including the
                # current one. The earlier slice [i-28:i-1] was empty (NaN) or wrapped
                # around to not-yet-written entries during the first iterations.
                recent = slice(max(0,i-27),i+1)
                avg_loss = run_loss[recent].mean().item()*100
                avg_psnr = run_psnr[recent].mean().item()
                str1 = f"iter: {i}, loss: {avg_loss:0.3f}, psnr: {avg_psnr:0.3f}, GPU max/memory: {torch.cuda.max_memory_allocated()*1e-9:0.2f} GByte"
                pbar.set_description(str1)
                pbar.update(1)
        # Weights are converted to half precision before their state dict is stored.
        model.half()
        models.append([model.state_dict(),embed_wb.state_dict(),run_psnr])
    return models
       

In [None]:
# Train/fit NeRVs separately for the first and the second image of each pair.
# fit_nerv always uses channel 0 of its input as the target, so slicing with [:,1:]
# makes the second image of each pair the one that is fitted in the second call.
# Cases are grouped into consecutive folds of 35, with one NeRV per fold.
nerv_models = fit_nerv(img_paired,35)
nerv_models1 = fit_nerv(img_paired[:,1:],35)

In [None]:
#torch.save(nerv_models,'nlst_reident_nervs.pth')
#torch.save(nerv_models1,'nlst_reident_nervs1.pth')

In [None]:
#k-anonymity mixing
def k_anonymous(nerv_models,shape,N_case,rho=8):
    """Decode images from fitted NeRVs using randomly mixed per-case codes.

    For every fold, each case's embedding is replaced by a linear mix of all
    embeddings of that fold: the mixing matrix is the identity plus Gaussian
    noise with standard deviation ``0.01 * rho``. The mixed codes are decoded by
    the fold's generator and written to channel 0 of the output.

    Args:
        nerv_models: list as returned by ``fit_nerv``; entry ``[0]`` is the
            generator state dict and entry ``[1]`` the embedding state dict.
        shape: shape of the output tensor; its last two entries are taken as
            the image height and width.
        N_case: number of cases per fold (must match the value used for fitting).
        rho: mixing strength; ``rho=0`` decodes the unmixed codes. Default 8,
            the value previously hard-coded.

    Returns:
        CUDA float tensor of shape ``shape``. Only channel 0 is filled; all
        other entries stay zero.

    Relies on the module-level ``args`` object for the generator configuration.
    """
    img_nerv = torch.zeros(shape).cuda()
    H,W = shape[-2:]

    # A single generator instance is reused; each fold's weights are loaded into it in turn.
    model = Generator(embed_length=args.embed_length, stem_dim_num=args.stem_dim_num, fc_hw_dim=args.fc_hw_dim, expansion=args.expansion, 
    num_blocks=args.num_blocks, norm=args.norm, act=args.act, bias = True, reduction=args.reduction, conv_type=args.conv_type,
    stride_list=args.strides,  sin_res=args.single_res,  lower_width=args.lower_width, sigmoid=args.sigmoid)
    model.cuda()
    for sub, state_dicts in enumerate(nerv_models):
        embed_wb = nn.Embedding(N_case,64).cuda()
        embed_wb.load_state_dict(state_dicts[1])
        model.load_state_dict(state_dicts[0])
        with torch.no_grad():
            idx = torch.arange(N_case).cuda()
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                # Row i of the mixing matrix: weight ~1 on case i, small random weights on the others.
                mixing = (torch.eye(N_case)+torch.randn(N_case,N_case)*.01*float(rho)).cuda()
                code = mixing.mm(embed_wb(idx))
                x = model(code)[-1][:,0]
                # Fold `sub` covers output rows sub*N_case .. (sub+1)*N_case - 1.
                img_nerv[idx+sub*N_case,0] = x.view(-1,H,W).float()

    return img_nerv
# Decode the mixed-code reconstructions for both sets of NeRVs.
# k_anonymous writes only channel 0 of its output, so in both tensors the
# reconstruction lives in channel 0 (channel 1 stays zero).
img_nerv = k_anonymous(nerv_models,img_paired.shape,35)
img_nerv1 = k_anonymous(nerv_models1,img_paired.shape,35)

In [None]:
# Re-assemble a two-channel paired tensor from the reconstructions:
# channel 0 comes from `nerv_models`, channel 1 from `nerv_models1`.
img_nerv_paired = torch.cat((img_nerv[:,:1],img_nerv1[:,:1]),1)

In [None]:
#train new re-ident model on NeRVed images
# Same setup as the image-based training cell, but fed with `img_nerv_paired`.
# All names below (model_, model, optimizer, run_sim, ...) overwrite those of the earlier run.
model_ = resnet34(weights=ResNet34_Weights.DEFAULT)
model_.fc = nn.Linear(512,256,bias=True)
model_.cuda()
# `model` is the compiled wrapper used for training; `model_` is passed to the evaluation.
model = torch.compile(model_)
model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
# StepLR(step_size=4000, gamma=0.2): learning rate is multiplied by 0.2 after 4000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,4000,0.2)
# Gradient scaler for the mixed-precision (autocast) forward pass below.
scaler = torch.cuda.amp.GradScaler()

iterations = 8000
# Per-iteration logs: in-batch retrieval accuracy and loss.
run_sim = torch.zeros(iterations)
run_loss = torch.zeros(iterations)
# Module-level cosine-similarity object (dim=2).
cos = nn.CosineSimilarity(dim=2, eps=1e-6)
# Number of cases per iteration; each contributes two views, so the network sees 2*B images.
B = 32
with trange(iterations) as pbar:
    for i in pbar:
        # Sample B distinct cases from the first 114; later cases are held out for evaluation.
        idx = torch.randperm(114)[:B]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # Randomly decide which of the two views goes into the first half of the batch.
            ij = torch.randperm(2)[:2]; ij1 = int(ij[0]); ij2 = int(ij[1]);
            # Identity affine plus Gaussian noise (std 0.075) per image; sampled onto a 360x360 grid.
            grid = F.affine_grid(torch.eye(2,3).unsqueeze(0).cuda()+torch.randn(B*2,2,3).cuda()*.075,(B*2,1,360,360))
            # Augmentation is done without tracking gradients.
            with torch.no_grad():
                inputs = transforms_train(F.grid_sample(torch.cat((img_nerv_paired[idx,ij1:ij1+1],img_nerv_paired[idx,ij2:ij2+1]),0),grid))
            # Single-channel images are repeated to 3 channels for the ResNet; embeddings are squashed by a sigmoid.
            output = torch.sigmoid(model(inputs.repeat(1,3,1,1)))
            loss,sim = infoNCE(output)
        
        run_loss[i] = loss.item()
        # Fraction of columns of `sim` whose largest entry lies on the diagonal (correct partner retrieved).
        run_sim[i] = (sim.argmax(0)==torch.arange(B).cuda()).float().mean()#sim.trace()/4
        scaler.scale(loss).backward()
        #scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        if(i>10):
            # Progress bar shows the mean over iterations i-10 .. i-2 (the slice excludes the two most recent).
            pbar.set_postfix(loss=run_loss[i-10:i-1].mean().item(),sim=run_sim[i-10:i-1].mean().item(),vram='%0.2f'%(torch.cuda.max_memory_allocated()*1e-9))
        

            
#torch.save(model_.state_dict(),'resnet2_nerv_rho008_nlst.pth')

In [None]:
#EVALUATE
# Same hold-out split as for the image-based model (cases from index 114 onward),
# but taken from the NeRV reconstructions and scored with the model trained on them.
top15_nerv = run_reident_pair(model_,img_nerv_paired[114:].clone(),quad=False)

In [None]:
#IMAGE:  top1 acc tensor(0.5312) top 5 acc tensor(0.7292)
#NERV:  top1 acc tensor(0.3958) top 5 acc tensor(0.5312)
