In [1]:
import os
import clip
import torch.nn as nn
from datasets import Action_DATASETS
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
import argparse
import shutil
from pathlib import Path
import yaml
from dotmap import DotMap
import pprint
import numpy
from modules.Visual_Prompt import visual_prompt
from utils.Augmentation import get_augmentation
import torch
from utils.Text_Prompt import *

  from .autonotebook import tqdm as notebook_tqdm


# Utility Functions

In [2]:
class TextCLIP(nn.Module):
    """Thin ``nn.Module`` adapter exposing only the text encoder of a CLIP model.

    Calling the wrapper forwards to ``model.encode_text``, which lets the text
    branch be treated as a standalone module (e.g. wrapped on its own).
    """

    def __init__(self, model):
        super().__init__()
        # Keep a reference to the full CLIP model; only encode_text is used.
        self.model = model

    def forward(self, text):
        """Return ``self.model.encode_text(text)``."""
        encode = self.model.encode_text
        return encode(text)

In [3]:
class ImageCLIP(nn.Module):
    """Thin ``nn.Module`` adapter exposing only the image encoder of a CLIP model.

    Calling the wrapper forwards to ``model.encode_image``, which lets the
    visual branch be treated as a standalone module (e.g. wrapped on its own).
    """

    def __init__(self, model):
        super().__init__()
        # Keep a reference to the full CLIP model; only encode_image is used.
        self.model = model

    def forward(self, image):
        """Return ``self.model.encode_image(image)``."""
        encode = self.model.encode_image
        return encode(image)

In [4]:
def validate(epoch, val_loader, classes, device, model, fusion_model, config, num_text_aug):
    """Run one evaluation pass and report top-1 / top-5 accuracy (in percent).

    Args:
        epoch: Epoch number, used only in the printed summary line.
        val_loader: Iterable yielding ``(image, class_id)`` batches.
        classes: Tokenised text prompts fed to ``model.encode_text``. Because the
            similarity matrix is reshaped with ``view(b, num_text_aug, -1)``,
            rows must be grouped by prompt template (all classes for template 0,
            then all classes for template 1, ...).
        device: Device the inputs are moved to.
        model: Object providing ``encode_text`` and ``encode_image``.
        fusion_model: Module reducing per-frame features ``(b, t, d)`` to one
            feature per clip.
        config: Configuration exposing ``data.num_segments`` and ``solver.epochs``.
        num_text_aug: Number of prompt templates per class.

    Returns:
        Top-1 accuracy as a percentage.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    fusion_model.eval()
    num = 0
    corr_1 = 0
    corr_5 = 0

    with torch.no_grad():
        text_inputs = classes.to(device)
        text_features = model.encode_text(text_inputs)
        # Loop-invariant: normalise the text embeddings once instead of
        # re-normalising them in place for every batch.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for image, class_id in tqdm(val_loader):
            # (b, t*3, h, w) -> (b, t, 3, h, w)
            image = image.view((-1, config.data.num_segments, 3) + image.size()[-2:])
            b, t, c, h, w = image.size()
            class_id = class_id.to(device).view(-1)

            # Encode every frame independently, then fuse frames per clip.
            image_input = image.to(device).view(-1, c, h, w)
            image_features = model.encode_image(image_input).view(b, t, -1)
            image_features = fusion_model(image_features)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Softmax over classes per prompt template, then average the templates.
            similarity = (100.0 * image_features @ text_features.T)
            similarity = similarity.view(b, num_text_aug, -1).softmax(dim=-1)
            similarity = similarity.mean(dim=1, keepdim=False)

            # topk raises if k exceeds the number of classes, so clamp k for
            # label sets with fewer than 5 classes.
            k = min(5, similarity.size(-1))
            _, indices_1 = similarity.topk(1, dim=-1)
            _, indices_5 = similarity.topk(k, dim=-1)

            targets = class_id.unsqueeze(-1)
            corr_1 += int((indices_1 == targets).any(dim=-1).sum())
            corr_5 += int((indices_5 == targets).any(dim=-1).sum())
            num += b

    if num == 0:
        raise ValueError(
            "validate() received no samples from val_loader; check the validation "
            "list and whether drop_last discards the only (partial) batch"
        )

    top1 = float(corr_1) / num * 100
    top5 = float(corr_5) / num * 100
    # Log both metrics in a single call so they are recorded at the same step.
    wandb.log({"top1": top1, "top5": top5})
    print('Epoch: [{}/{}]: Top1: {}, Top5: {}'.format(epoch, config.solver.epochs, top1, top5))
    return top1

# Main Model Run

In [18]:
# Run settings that would normally come from the command line.
args_type = 'clip_k400'
args_arch = 'ViT-B/16'
args_dataset = 'hl'
args_log_time = '20240229_154250'

# Read the experiment parameters from the YAML config file.
# NOTE(review): this is an absolute, machine-specific path.
args_config= "/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/ActionCLIP/configs/k400/k400_zero_shot.yaml"
# NOTE(review): FullLoader is used here; consider yaml.safe_load if the config
# file could ever come from an untrusted source.
with open(args_config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)


In [19]:
# Output directory layout: ./exp/<network type>/<arch>/<dataset>/<log time>.
# An arch name containing '/' (e.g. 'ViT-B/16') adds an extra directory level.
working_dir_parts = [
    './exp',
    config['network']['type'],
    config['network']['arch'],
    config['data']['dataset'],
    args_log_time,
]
working_dir = os.path.join(*working_dir_parts)

In [20]:
# Run name: <log time>_<network type>_<arch>_<dataset>.
wandb_run_name = '{}_{}_{}_{}'.format(
    args_log_time,
    config['network']['type'],
    config['network']['arch'],
    config['data']['dataset'],
)
wandb.init(project=config['network']['type'], name=wandb_run_name)



In [21]:
# Banner with the working directory, followed by the full configuration.
separator = '-' * 80

print(separator)
print(' ' * 20, "working dir: {}".format(working_dir))
print(separator)

print(separator)
print(' ' * 30, "Config")
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(config)
print(separator)

--------------------------------------------------------------------------------
                     working dir: ./exp/clip_k400/ViT-B/16/hl/20240229_154250
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
                               Config
{   'data': {   'batch_size': 64,
                'dataset': 'hl',
                'image_tmpl': 'img_{:05d}.jpg',
                'index_bias': 1,
                'input_size': 224,
                'label_list': '/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/action_clip_data/hl_val_labels.csv',
                'modality': 'RGB',
                'num_classes': 3,
                'num_segments': 8,
                'random_shift': False,
                'seg_length': 1,
                'split': 1,
                'val_list': '/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/g

In [22]:
# convert yaml params to dot notation
config = DotMap(config)
print(config.pretrain)

/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/ActionCLIP/weights/vit-b-16-32f/vit-b-16-32f.pt


In [23]:
# Create the experiment directory (including parents; no error if it exists)
# and snapshot the config file and the test script into it.
os.makedirs(working_dir, exist_ok=True)
shutil.copy(src=args_config, dst=working_dir)
shutil.copy(src='test.py', dst=working_dir)

'./exp/clip_k400/ViT-B/16/hl/20240229_154250/test.py'

In [24]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
# If using GPU then use mixed precision training.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [25]:
# Build the CLIP model from the configured architecture; the call also returns
# the state dict, which is needed later to build the fusion model.
# Must set jit=False for training  ViT-B/32
model, clip_state_dict = clip.load(
    config.network.arch,
    device=device,
    jit=False,
    tsm=config.network.tsm,
    T=config.data.num_segments,
    dropout=config.network.drop_out,
    emb_dropout=config.network.emb_dropout,
)

print(model)

dropout used:[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
dropout used:[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
loading clip pretrained model!
CLIP(
  (visual): VisualTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (dropout): Dropout(p=0.0, inplace=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (drop_path): Identity()
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (

In [26]:
# Build the transform pipeline applied to validation frames.
# NOTE(review): the first argument is presumably a "training" flag (False ->
# evaluation-time transforms) -- confirm against utils.Augmentation.
transform_val = get_augmentation(False, config)

In [27]:
# Build the temporal fusion head that combines the per-frame image features of
# a clip; it is configured by the sim_header setting, the CLIP state dict and
# the number of frames (segments) sampled per clip.
fusion_model = visual_prompt(config.network.sim_header, clip_state_dict, config.data.num_segments)

layer=6


In [28]:
# Expose the text and image encoders of the shared CLIP model as two separate
# modules; both wrappers reference the same underlying `model`.
model_text, model_image = TextCLIP(model), ImageCLIP(model)

In [29]:
# Wrap each sub-model in DataParallel and place it on the device chosen earlier.
# `.to(device)` is used instead of an unconditional `.cuda()` so this cell also
# runs on the CPU fallback selected when no GPU is available; on a GPU machine
# it is equivalent to `.cuda()`.
model_text = torch.nn.DataParallel(model_text).to(device)
model_image = torch.nn.DataParallel(model_image).to(device)
fusion_model = torch.nn.DataParallel(fusion_model).to(device)
# Let wandb track the CLIP model and the fusion head.
wandb.watch(model)
wandb.watch(fusion_model)

[]

In [30]:
# Validation dataset: frame list + label list, sampled with num_segments frames
# per clip and preprocessed with the validation transform.
# `random_shift` lives under the `data` section of the config; reading it as
# `config.random_shift` yields an empty (falsy) DotMap for the unknown key, so
# the configured value was silently ignored.
val_data = Action_DATASETS(config.data.val_list, config.data.label_list, num_segments=config.data.num_segments,
                    image_tmpl=config.data.image_tmpl,
                    transform=transform_val, random_shift=config.data.random_shift)
# NOTE(review): drop_last=True discards the final partial batch, so those
# samples are never evaluated (and a validation set smaller than batch_size
# yields no batches at all) -- confirm this is intended.
val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=config.data.workers, shuffle=False,
                        pin_memory=True, drop_last=True)
print(config.data.val_list)
print(config.data.label_list)


/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/action_clip_data/hl_val_frames_re.txt
/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/action_clip_data/hl_val_labels.csv


In [31]:
# Pick the numeric precision of the encoders based on the device in use.
if device != "cpu":
    # Actually this line is unnecessary since clip by default already on float16
    clip.model.convert_weights(model_text)
    clip.model.convert_weights(model_image)
else:
    # On CPU the encoders are switched to float32.
    model_text.float()
    model_image.float()

# Epoch index reported by the evaluation run.
start_epoch = config.solver.start_epoch

In [32]:
# Optionally restore the CLIP model and fusion head from a pretrained checkpoint.
if config.pretrain:
    if os.path.isfile(config.pretrain):
        print(("=> loading checkpoint '{}'".format(config.pretrain)))
        # Deserialize onto the CPU first: load_state_dict then copies the
        # tensors onto whatever device the parameters already live on. This
        # lets a checkpoint saved from a GPU load on a CPU-only machine and
        # avoids holding a second full copy of the weights on the GPU.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(config.pretrain, map_location="cpu")
        model.load_state_dict(checkpoint['model_state_dict'])
        fusion_model.load_state_dict(checkpoint['fusion_model_state_dict'])
        # Release the checkpoint copy as soon as the weights are restored.
        del checkpoint
    else:
        # Evaluation continues with the weights already in the models.
        print(("=> no checkpoint found at '{}'".format(config.pretrain)))




=> loading checkpoint '/home/regal/devel/ws_cacti/src/hri_cacti_xr/gesture_recognition/gesture_recogition_research/ActionCLIP/weights/vit-b-16-32f/vit-b-16-32f.pt'


In [35]:
# Build the tokenised class prompts for the validation label set.
best_prec1 = 0.0
classes, num_text_aug, text_dict = text_prompt(val_data)

# Sanity-check the loader, the prompt tensor size and the number of templates.
for summary_item in (val_loader, classes.size(), num_text_aug):
    print(summary_item)
#prec1 = validate(start_epoch, val_loader, classes, device, model, fusion_model, config, num_text_aug)

<torch.utils.data.dataloader.DataLoader object at 0x7febedd3f100>
torch.Size([48, 77])
16
