In [1]:
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# Segment anything
from Mobile_SAM.build_sam import sam_model_registry
from Mobile_SAM.utils.transforms import ResizeLongestSide

#CLIP
from models.backbones.backbone import CLIPViTFM, clip_backbone

from cluster.Group import Cluster_GPU
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import utils as vutils
import numpy as np
import yaml
from PIL import ImageFilter
import random
import json 
import cv2
import matplotlib.pyplot as plt
import os
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import builtins
from tqdm import tqdm
import time
from utils import AverageMeter, ProgressMeter, to_log, accuracy

  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)


In [2]:
class TwoWayTransform:
    def __init__(self, base_transform_a,
        base_transform_b, fixed_aug_shot=True):
        self.base_transform_a = base_transform_a
        self.base_transform_b = base_transform_b
        self.fixed = fixed_aug_shot

    def __call__(self, x):
        frame_num = len(x)
        if self.fixed:
            seed = np.random.randint(2147483647)
            q, k = [], []
            for i in range(frame_num):
                random.seed(seed)
                q.append(self.base_transform_a(x[i]))
            seed = np.random.randint(2147483647)
            for i in range(frame_num):
                random.seed(seed)
                k.append(self.base_transform_b(x[i]))
        else:
            q = [self.base_transform_a(x[i]) for i in range(frame_num)]
            k = [self.base_transform_b(x[i]) for i in range(frame_num)]
        q = torch.cat(q, axis = 0)
        k = torch.cat(k, axis = 0)
        return [q, k]

class MovieNet_Shot_Dataset(torch.utils.data.Dataset):
    def __init__(self, img_path, shot_info_path, transform,
        shot_len = 16, frame_per_shot = 3, _Type='train'):
        self.img_path = img_path
        with open(shot_info_path, 'rb') as f:
            self.shot_info = json.load(f)
        self.img_path = img_path
        self.shot_len = shot_len
        self.frame_per_shot = frame_per_shot
        self.transform = transform
        self._Type = _Type.lower()
        assert self._Type in ['train','val','test']
        self.idx_imdb_map = {}
        data_length = 0
        for imdb, shot_num in self.shot_info[_Type].items():
            for i in range(shot_num // shot_len):
                self.idx_imdb_map[data_length] = (imdb, i)
                data_length += 1
                
    def __len__(self):
        return len(self.idx_imdb_map.keys())
    
    def _transform(self, img_list):
        q, k = [], []
        for item in img_list:
            out = self.transform(item)
            q.append(out[0])
            k.append(out[1])
        out_q = torch.stack(q, axis=0)
        out_k = torch.stack(k, axis=0)
        return [out_q, out_k]
    
    def _process_puzzle(self, idx):
        imdb, puzzle_id = self.idx_imdb_map[idx]
        img_path =  f'{self.img_path}/{imdb}/{str(puzzle_id).zfill(4)}.jpg'
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.vsplit(img, self.shot_len)
        img = [np.hsplit(i, self.frame_per_shot) for i in img]
        data = self._transform(img)
        return data
    
    def __getitem__(self, idx):
        return self._process_puzzle(idx)

def get_train_loader(cfg):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225]) 
    augmentation_base = [
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #normalize
    ]

    train_transform = TwoWayTransform(
        transforms.Compose(augmentation_base), 
        transforms.Compose(augmentation_base),
        fixed_aug_shot=cfg['data']['fixed_aug_shot'])

    img_path = cfg['data']['data_path'] 
    shot_info_path = cfg['data']['shot_info'] 
    train_dataset = MovieNet_Shot_Dataset(img_path, shot_info_path, train_transform)
    train_sampler = None
    if False:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, 
        shuffle=(train_sampler is None),
        batch_size=16, num_workers=8, pin_memory=True, drop_last=True) #36
    return train_loader

In [3]:
cfg = yaml.safe_load(open("./config/vllip_pretrain.yaml", encoding='utf8'))
def load_dino(model_config_path, model_checkpoint_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model
dino = load_dino(cfg["model"]["dino_config_path"], cfg["model"]["dino_pretrain"], device=f'cuda:{torch.cuda.current_device()}')
sam = sam_model_registry["vit_t"](checkpoint=cfg["model"]["sam_pretrain"]).cuda()
clip_vit = CLIPViTFM(batch_size=16*16, model_name='ViT-B/32').cuda()
#clip_res = clip_backbone(batch_size=16*16, model_name='RN50').cuda()
train_loader = get_train_loader(cfg)  
print(int(torch.cuda.max_memory_allocated(device='cuda')/1024**2),"MB") #DEBUG

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased
_IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])
625 MB


In [4]:
class VLLIP_vit_encoder(nn.Module):
    def __init__(self,dino, sam, clip):
        super(VLLIP_vit_encoder, self).__init__()
        self.dino = dino
        self.sam = sam
        self.clip = clip
         
    def forward(self, images, type):
        #shots = images.clone()
        images = images.reshape(images.size()[0]*3, 3, 224, 224)
        assert type =='train' or type=='test'
        if type=='train':
            mask_list = torch.ones(images.size()[0], 1, 224, 224)
            #mask_list = mask_list.reshape(int(images.size()[0]/3),3,224,224)

        elif type=='test':    
            # Ground DINO
            boxes_filt, _, _ = self.get_grounding_output(
                self.dino, images, "human", 0.45, 0.25)
            torch.cuda.empty_cache()
            
            # SAM
            images = (images * 255).clamp(0, 255).to(torch.uint8) # bs, 3, 224, 224
            resize_transform = ResizeLongestSide(self.sam.image_encoder.img_size)
        
            """
            for i in range(10,14): # !!!DEBUG!!! 
                plt.imshow(images[i].permute(1, 2, 0))
                plt.show()
            """
        
            batched_input = []
            for i in range(images.size()[0]):                
                prepared_image = self.prepare_image(images[i], resize_transform)   
                transformed_boxes = self.transform_boxes(boxes_filt[i],images[i],resize_transform)
                batched_input.append({
                    'image': prepared_image,
                    'boxes': transformed_boxes,
                    'original_size': images[i].shape[1:] 
                })
            with torch.no_grad():    
                outputs = self.sam(batched_input, multimask_output=False)
            
            mask_list = []
            for i, output in enumerate(outputs):
                if output["masks"] != None:
                    output["masks"] = output["masks"].cpu()
                    output["masks"] = output["masks"].any(dim=0)
                    output["masks"] = (~output["masks"]).int()
                    mask_list.append(output["masks"])
                else:
                    mask_list.append(torch.ones([1, 224, 224]))
            
            torch.cuda.empty_cache()
            mask_list = torch.stack(mask_list)
        
        """
        main_color = np.array([255, 255, 255, 255]) #배경색 - 흰색
        subtract_color = np.array([255, 255, 255, 0]) #마스크(사람) - 검정색
        for i in range(10,14): # !!!DEBUG!!! 
            H, W =images[i].size()[1], images[i].size()[2] #224, 224
            mask_data = mask_list[i].numpy()
            bg_image = np.full((H, W, 4), main_color, dtype=np.float64)
            mask_image = mask_data.reshape(H, W, 1) * subtract_color.reshape(1, 1, -1)
            final_image = bg_image.copy()
            final_image -= mask_image
            plt.imshow(final_image)
            plt.show()
        """

        #CLIP ViT
        visual_features = self.clip(images.cuda(), mask_list.cuda(), masking_type='token_masking', masking_block=9)
        visual_features = visual_features.cpu()
        torch.cuda.empty_cache()
        return visual_features

    @torch.no_grad()
    def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, with_logits=True):
        caption = caption.lower()
        caption = caption.strip()
        if not caption.endswith("."):
            caption = caption + "."
        captions = [caption]*image.size()[0]
        images = image.cuda()
        model = model.cuda()
        with torch.no_grad():
            outputs = model(images, captions=captions)

        prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # (bs, nq, 256)
        prediction_boxes = outputs["pred_boxes"].cpu()  # (bs, nq, 4)
        
        logits_res = [] #length: bs
        boxs_res = [] #length: bs
        phrases_list = [] #length: bs
        tokenizer = model.tokenizer
        for ub_logits, ub_boxes, ub_captions in zip(prediction_logits, prediction_boxes, captions):
            mask = ub_logits.max(dim=1)[0] > box_threshold
            logits = ub_logits[mask]  # logits.shape = (n, 256)
            boxes = ub_boxes[mask]  # boxes.shape = (n, 4)
            logits_res.append(logits.max(dim=1)[0])
            boxs_res.append(boxes)

            tokenized = tokenizer(ub_captions)
            phrases = [
                get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
                for logit
                in logits
            ]
            phrases_list.append(phrases)
        return boxs_res, phrases_list, logits_res,
    @torch.no_grad()
    def transform_boxes(self, boxes_filt, image, transform):
        H, W = image.size()[1], image.size()[2] #224, 224
        for i in range(boxes_filt.size(0)): #XYWH -> X1Y1X2Y2 형식 변경
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]
        boxes_filt = boxes_filt.cpu()
        transformed_boxes = transform.apply_boxes_torch(boxes_filt, image.shape[1:]).cuda()
        return transformed_boxes
    @torch.no_grad()
    def prepare_image(self, image, transform):
        image = transform.apply_image(image) #3, 224, 224 -> 1024, 1024, 3
        image = torch.as_tensor(image) #to cuda
        image = image.cuda()
        return image.permute(2, 0, 1).contiguous() #1024, 1024, 3 -> 3, 1024, 1024

In [4]:
class VLLIP_res_encoder(nn.Module):
    def __init__(self,dino, sam, clip):
        super(VLLIP_res_encoder, self).__init__()
        self.dino = dino
        self.sam = sam
        self.clip = clip
         
    def forward(self, images, type):
        shots = images.clone()
        images = images.reshape(images.size()[0]*3, 3, 224, 224)
        assert type =='train' or type=='test'
        if type=='train':
            mask_list = torch.ones(images.size()[0], 1, 224, 224)
            mask_list = mask_list.reshape(int(images.size()[0]/3),3,224,224)

        elif type=='test':    
            # Ground DINO
            boxes_filt, _, _ = self.get_grounding_output(
                self.dino, images, "human", 0.45, 0.25)
            torch.cuda.empty_cache()
            
            # SAM
            images = (images * 255).clamp(0, 255).to(torch.uint8) # bs, 3, 224, 224
            resize_transform = ResizeLongestSide(self.sam.image_encoder.img_size)
        
            """
            for i in range(10,14): # !!!DEBUG!!! 
                plt.imshow(images[i].permute(1, 2, 0))
                plt.show()
            """
        
            batched_input = []
            for i in range(images.size()[0]):                
                prepared_image = self.prepare_image(images[i], resize_transform)   
                transformed_boxes = self.transform_boxes(boxes_filt[i],images[i],resize_transform)
                batched_input.append({
                    'image': prepared_image,
                    'boxes': transformed_boxes,
                    'original_size': images[i].shape[1:] 
                })
            with torch.no_grad():    
                outputs = self.sam(batched_input, multimask_output=False)
            
            mask_list = []
            for i, output in enumerate(outputs):
                if output["masks"] != None:
                    output["masks"] = output["masks"].cpu()
                    output["masks"] = output["masks"].any(dim=0)
                    output["masks"] = (~output["masks"]).int()
                    mask_list.append(output["masks"])
                else:
                    mask_list.append(torch.ones([1, 224, 224]))
            
            torch.cuda.empty_cache()
            mask_list = torch.stack(mask_list)

        #CLIP ResNet
        visual_features = self.clip.feature_map_masking(shots.cuda(), mask_list.cuda())
        visual_features.cpu()
        torch.cuda.empty_cache()
        return visual_features

    @torch.no_grad()
    def get_grounding_output(self, model, image, caption, box_threshold, text_threshold, with_logits=True):
        caption = caption.lower()
        caption = caption.strip()
        if not caption.endswith("."):
            caption = caption + "."
        captions = [caption]*image.size()[0]
        images = image.cuda()
        model = model.cuda()
        with torch.no_grad():
            outputs = model(images, captions=captions)

        prediction_logits = outputs["pred_logits"].cpu().sigmoid()  # (bs, nq, 256)
        prediction_boxes = outputs["pred_boxes"].cpu()  # (bs, nq, 4)
        
        logits_res = [] #length: bs
        boxs_res = [] #length: bs
        phrases_list = [] #length: bs
        tokenizer = model.tokenizer
        for ub_logits, ub_boxes, ub_captions in zip(prediction_logits, prediction_boxes, captions):
            mask = ub_logits.max(dim=1)[0] > box_threshold
            logits = ub_logits[mask]  # logits.shape = (n, 256)
            boxes = ub_boxes[mask]  # boxes.shape = (n, 4)
            logits_res.append(logits.max(dim=1)[0])
            boxs_res.append(boxes)

            tokenized = tokenizer(ub_captions)
            phrases = [
                get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
                for logit
                in logits
            ]
            phrases_list.append(phrases)
        return boxs_res, phrases_list, logits_res,
    @torch.no_grad()
    def transform_boxes(self, boxes_filt, image, transform):
        H, W = image.size()[1], image.size()[2] #224, 224
        for i in range(boxes_filt.size(0)): #XYWH -> X1Y1X2Y2 형식 변경
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]
        boxes_filt = boxes_filt.cpu()
        transformed_boxes = transform.apply_boxes_torch(boxes_filt, image.shape[1:]).cuda()
        return transformed_boxes
    @torch.no_grad()
    def prepare_image(self, image, transform):
        image = transform.apply_image(image) #3, 224, 224 -> 1024, 1024, 3
        image = torch.as_tensor(image) #to cuda
        image = image.cuda()
        return image.permute(2, 0, 1).contiguous() #1024, 1024, 3 -> 3, 1024, 1024

In [5]:
class VLLIP(nn.Module):
    def __init__(self, dino, sam, clip, type):
        super(VLLIP,self).__init__()
        self.dino = dino
        self.sam = sam
        self.clip = clip
        self.type = type

        self.q_encoder = VLLIP_vit_encoder(self.dino, self.sam, self.clip)
        self.k_encoder = VLLIP_vit_encoder(self.dino, self.sam, self.clip)
        #self.q_encoder = VLLIP_res_encoder(self.dino, self.sam, self.clip)
        #self.k_encoder = VLLIP_res_encoder(self.dino, self.sam, self.clip)
        
        self.cluster_num = 24
        self.cluster_obj = Cluster_GPU(self.cluster_num)
        self.multi_positive = True
        self.soft_gamma = 0.5
        self.K = 16384 #32768 #16384 #65536
        self.m = 0.999
        self.T = 0.07
        self.dim = 512
        
        for param_q, param_k in zip(self.q_encoder.clip.parameters(), self.k_encoder.clip.parameters()): #Q인코더 K인코더로 파라미터 복사 후 freeze
            param_k.data.copy_(param_q.data)  
            param_k.requires_grad = False 
        
        # create the queue
        self.register_buffer("queue", torch.randn(self.dim, self.K)) # 512 x 16384 행렬
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        # queue 길이 = 16384, queue 차원 = 512

    def forward(self, img_q, img_k): #SCRL과 동일함
        embeddings = self.q_encoder(img_q, self.type) # bs, 512
        embeddings = nn.functional.normalize(embeddings, dim=1) # bs, 512
        
        # get q and k index
        index_q, index_k = self.get_q_and_k_index_cluster(embeddings) # q 임베딩 입력 -> 0~bs, k-index
        
        # features of q
        q = embeddings[index_q] # embeddings
        q = q.cuda()

        # compute key features
        with torch.no_grad():  
            # update the key encoder
            self._momentum_update_key_encoder() #k encoder param update

            # shuffle for making use of BN
            # img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)

            k = self.k_encoder(img_k, self.type) # bs, 512
            k = nn.functional.normalize(k, dim=1)
            k = k.cuda()
            
            # undo shuffle
            # k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        k_ori = k
        k = k[index_k] #k indexes 선택

        # compute logits
        # positive logits: Nx1
        if self.multi_positive: # True
            # SCRL Soft-SC 
            k = (k + k_ori) * self.soft_gamma # original_k + k

        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) #내적 - attntion 같은 느낌

        
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) #외적

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1) #연접

        # apply temperature
        logits /= self.T #scalinng

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() #전부 0 나옴

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)
        
        return logits, labels

    @torch.no_grad()
    def get_q_and_k_index_cluster(self, embeddings, return_group=False) -> tuple:

        B = embeddings.size(0) #bs
        target_index = list(range(0, B)) #0~bs
        q_index = target_index #0~bs

        choice_cluster, choice_points = self.cluster_obj(embeddings) #clustering
        k_index = []
        for c in choice_cluster:
            k_index.append(int(choice_points[c])) #cluster centers
        if return_group: #False
            return (q_index, k_index, choice_cluster, choice_points)
        else:
            return (q_index, k_index)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.q_encoder.clip.parameters(), self.k_encoder.clip.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
    

    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        # keys = concat_all_gather(keys)

        batch_size = keys.shape[0]
        
        ptr = int(self.queue_ptr)
        
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    """
    @torch.no_grad()
    def concat_all_gather(tensor):
        
        #Performs all_gather operation on the provided tensors.
        #*** Warning ***: torch.distributed.all_gather has no gradient.
        
        tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

        output = torch.cat(tensors_gather, dim=0)
        return output

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        
        #Batch shuffle, for making use of BatchNorm.
        #*** Only support DistributedDataParallel (DDP) model. ***
        
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle
    
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        
        #Undo batch shuffle.
        #*** Only support DistributedDataParallel (DDP) model. ***
        
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]
    """

    @torch.no_grad()
    def get_q_and_k_index_cluster(self, embeddings, return_group=False) -> tuple:

        B = embeddings.size(0) # bs
        target_index = list(range(0, B)) # 0...bs-1
        q_index = target_index # 0...bs-1

        choice_cluster, choice_points = self.cluster_obj(embeddings)
        k_index = []
        for c in choice_cluster:
            k_index.append(int(choice_points[c]))
        if return_group:
            return (q_index, k_index, choice_cluster, choice_points)
        else:
            return (q_index, k_index)


In [6]:
import models.backbones.visual.resnet as resnet
from models.core.SCRL_MoCo import SCRL

import wandb

wandb.init(project='VLLIP', entity='nstar1125')
wandb.run.name = 'test_run-1'
wandb.run.save()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


model = VLLIP(
                dino=dino,
                sam=sam,
                clip=clip_vit,
                type = 'train'
            ).cuda()

model.train()
model.clip.train()
optimizer = torch.optim.SGD(model.clip.parameters(), cfg['optim']['lr'],
                                    momentum=cfg['optim']['momentum'],
                                    weight_decay=cfg['optim']['wd'])
"""
o_model = SCRL(
            base_encoder                = resnet.encoder_resnet50,
            dim                         = cfg['MoCo']['dim'], 
            K                           = cfg['MoCo']['k'], 
            m                           = cfg['MoCo']['m'], 
            T                           = cfg['MoCo']['t'], 
            mlp                         = True, 
            encoder_pretrained_path     = "./pretrain/resnet50-19c8e357.pth",
            multi_positive              = cfg['MoCo']['multi_positive'],
            positive_selection          = cfg['model']['Positive_Selection'],
            cluster_num                 = cfg['model']['cluster_num'],
            soft_gamma                  = cfg['model']['soft_gamma'],
            ).cuda()
o_model.train()
optimizer = torch.optim.SGD(o_model.parameters(), cfg['optim']['lr'],
                                    momentum=cfg['optim']['momentum'],
                                    weight_decay=cfg['optim']['wd'])
"""
criterion = torch.nn.CrossEntropyLoss()

print(int(torch.cuda.memory_allocated(device='cuda')/1024**2),"MB") #DEBUG

for epoch in range(4):
    #for i, data in tqdm(enumerate(train_loader),total=len(train_loader),ncols=100, position=0, leave=True):
    for i, data in enumerate(train_loader):
        pivot = time.time()
        view_size = (-1, 3 * cfg['data']['frame_size'], 224, 224)
        data_q = data[0].cuda(torch.cuda.current_device(), non_blocking=True)
        data_k = data[1].cuda(torch.cuda.current_device(), non_blocking=True)
        data_q = data_q.view(view_size)
        data_k = data_k.view(view_size) 
    
        output, target = model(data_q, data_k) 
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss.requires_grad_(True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"[{epoch} epoch]({i}/{len(train_loader)} iteration) lr: {optimizer.param_groups[0]['lr']:.4f}, loss: {loss:.4f}, acc1: {acc1[0]:.4f}, acc5: {acc5[0]:.4f}, elapsed_time: {(time.time()-pivot):.4f}s")
        wandb.log({
            "lr": optimizer.param_groups[0]['lr'],
            "loss": loss,
            "acc1": acc1,
            "acc5": acc5
        })
        pivot = time.time()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnstar1125[0m. Use [1m`wandb login --relogin`[0m to force relogin




1316 MB
[0 epoch](0/3742 iteration) lr: 0.0000, loss: 0.9928, acc1: 100.0000, acc5: 100.0000, elapsed_time: 1.9559s
[0 epoch](1/3742 iteration) lr: 0.0000, loss: 3.0997, acc1: 87.1094, acc5: 93.3594, elapsed_time: 1.6778s
[0 epoch](2/3742 iteration) lr: 0.0000, loss: 3.7281, acc1: 77.3438, acc5: 89.4531, elapsed_time: 1.6112s
[0 epoch](3/3742 iteration) lr: 0.0000, loss: 3.9636, acc1: 82.4219, acc5: 94.1406, elapsed_time: 1.8219s
[0 epoch](4/3742 iteration) lr: 0.0000, loss: 4.1853, acc1: 72.2656, acc5: 88.2812, elapsed_time: 1.5686s
[0 epoch](5/3742 iteration) lr: 0.0000, loss: 4.5402, acc1: 63.2812, acc5: 82.0312, elapsed_time: 1.3769s
[0 epoch](6/3742 iteration) lr: 0.0000, loss: 4.5786, acc1: 72.2656, acc5: 85.9375, elapsed_time: 1.3791s
[0 epoch](7/3742 iteration) lr: 0.0000, loss: 4.6694, acc1: 75.7812, acc5: 88.2812, elapsed_time: 1.3833s
[0 epoch](8/3742 iteration) lr: 0.0000, loss: 4.6106, acc1: 82.8125, acc5: 91.7969, elapsed_time: 1.3807s
[0 epoch](9/3742 iteration) lr: 0.00

KeyboardInterrupt: 

In [10]:
for k, v in clip_res.named_parameters():
    print(k)

model.positional_embedding
model.text_projection
model.logit_scale
model.visual.conv1.weight
model.visual.bn1.weight
model.visual.bn1.bias
model.visual.conv2.weight
model.visual.bn2.weight
model.visual.bn2.bias
model.visual.conv3.weight
model.visual.bn3.weight
model.visual.bn3.bias
model.visual.layer1.0.conv1.weight
model.visual.layer1.0.bn1.weight
model.visual.layer1.0.bn1.bias
model.visual.layer1.0.conv2.weight
model.visual.layer1.0.bn2.weight
model.visual.layer1.0.bn2.bias
model.visual.layer1.0.conv3.weight
model.visual.layer1.0.bn3.weight
model.visual.layer1.0.bn3.bias
model.visual.layer1.0.downsample.0.weight
model.visual.layer1.0.downsample.1.weight
model.visual.layer1.0.downsample.1.bias
model.visual.layer1.1.conv1.weight
model.visual.layer1.1.bn1.weight
model.visual.layer1.1.bn1.bias
model.visual.layer1.1.conv2.weight
model.visual.layer1.1.bn2.weight
model.visual.layer1.1.bn2.bias
model.visual.layer1.1.conv3.weight
model.visual.layer1.1.bn3.weight
model.visual.layer1.1.bn3.bias