# import 

In [None]:
import random
import numpy as np
import torch

def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator, the PyTorch CPU
    generator, and the PyTorch generators of the current and all CUDA devices.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch: CPU generator, current CUDA device, then every CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


seed_value = 33
set_seed(seed_value)

In [None]:
import sys
sys.path.append('./')


from model import *
from zero_shot_test import *
from data import *

import yaml

cfg_path = 'config_audio.yaml'
# Parse the YAML config; the context manager closes the file handle right away
# instead of leaving it open for the lifetime of the kernel.
with open(cfg_path, encoding='utf-8') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
# Single device string used by every later cell.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

# reprogram

## train data load

In [None]:
# label setup
import sys
sys.path.append('./')

from preprocessing.audioset import *

# AudioSet metadata: the class-index CSV and the balanced-train segment list.
audioset_dir = './dataset/audio/audioset'
path = f'{audioset_dir}/class_labels_indices.csv'
csv_file = f'{audioset_dir}/balanced_train_segments.csv'
root_dir = f"{audioset_dir}/train"

gt_classes = get_gt(path)

# Collect the downloaded clips under root_dir, then look up their targets
# (get_target comes from preprocessing.audioset).
audio_list, video_list = get_audio_n_video_path(root_dir)
train_target = get_target(path, csv_file, audio_list)

In [None]:
from data import AudioSetDataset
from preprocessing.preprocessing_utils import get_img_preprocess_with_rgb

# Preprocessing for the visual stream, passed to the dataset next to the audio
# (presumably image/frame transforms, judging by the helper's name).
video_process = get_img_preprocess_with_rgb()

# Audio normalisation statistics are passed explicitly: mean -4.2677393, std 4.56.
train_val_dataset = AudioSetDataset(
    root_dir,
    device,
    video_process,
    clips_per_video=1,
    audio_mean=-4.2677393,
    audio_std=4.56,
)


In [None]:
from torch.utils.data import random_split


dataset_size = len(train_val_dataset)
# 90/10 train/validation split. Validation takes the remainder, so the two
# sizes always add up to dataset_size.
train_size = int(0.9 * dataset_size)
val_size = dataset_size - train_size
split_sizes = [train_size, val_size]
train_dataset, val_dataset = random_split(train_val_dataset, split_sizes)



In [None]:
# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=cfg['batch_size'],
    shuffle=True,
    drop_last=False,
    pin_memory=False,
    sampler=None,
)

# Validation uses a fixed order. The loss is computed from a batch-by-batch
# logit matrix, so it depends on which clips share a batch. Shuffling here
# would make the per-epoch validation loss (and therefore best-checkpoint
# selection and early stopping) noisy.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=cfg['batch_size'],
    shuffle=False,
    drop_last=False,
    pin_memory=False,
    sampler=None,
)

In [None]:
# Sanity check: show the shapes and contents of a single training batch.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    source, target = first_batch
    print(source.shape, target.shape)
    print(target)
    print(source)

## model_load

In [None]:
import torch
# Load the pretrained reprogramming model. This checkpoint is presumably a whole
# pickled nn.Module (the training loop saves checkpoints the same way, with
# torch.save(model, ...)). PyTorch >= 2.6 defaults to weights_only=True, which
# rejects such files, so opt out explicitly.
# Unpickling can run arbitrary code: only load checkpoints you trust.
model = torch.load('./trained_model/reprogram_base.pth', weights_only=False)

In [None]:

# Collect the parameters whose second dotted name component is 'audio'
# (names shaped like '<prefix>.audio.<...>'); only these go to the optimizer.
# NOTE(review): other parameters still get gradients if requires_grad is True,
# and optimizer.zero_grad() never clears them. Consider freezing them
# (param.requires_grad_(False)) to save memory and compute.
trainable_params = []

for name, param in model.named_parameters():
    name_parts = name.split('.')
    # Names without a second component cannot match. Check explicitly instead
    # of relying on a bare `except`, which also hid unrelated errors.
    if len(name_parts) > 1 and name_parts[1] == 'audio':
        print(f"Parameter: {name}, Requires Grad: {param.requires_grad}")
        trainable_params.append(param)

In [None]:
# Release cached-but-unused CUDA memory held by PyTorch's caching allocator before training.
torch.cuda.empty_cache()


## train

In [None]:
# Move the model to the device the training loop sends its batches to. Unlike a
# hard-coded .cuda(), this also works when only the CPU is available.
model.to(device)

In [None]:
# Training hyper-parameters. This replaces any 'train_params' entry the YAML had.
cfg['train_params'] = dict(
    optimizer='AdamW',
    init_lr=0.001,
    weight_decay=0.2,
    scheduler='cosw',   # 'cosw' -> CosineAnnealingLR; the loop also understands 'plateau'
    temperature=0.05,   # contrastive logits are scaled by 1 / temperature
    T_max=60,           # CosineAnnealingLR period, in scheduler steps (one per epoch)
)

In [None]:
import torch.optim as optim
optimizer = optim.AdamW(
    trainable_params,
    lr=cfg['train_params']['init_lr'],
    weight_decay=cfg['train_params']['weight_decay'],
)

# Build the scheduler the config asks for. These are the two names the training
# loop knows how to step: 'cosw' is stepped once per epoch, and 'plateau' is
# stepped with the validation loss.
scheduler_name = cfg['train_params']['scheduler']
if scheduler_name == 'cosw':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['train_params']['T_max'])
elif scheduler_name == 'plateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.01, patience=10)
else:
    raise ValueError(f"Unknown scheduler: {scheduler_name!r}")

In [None]:

# Where the training loop writes the checkpoint with the lowest validation loss (via torch.save).
save_best_path = './trained_model/audio/audioset_best.pth'

In [None]:
from utils import *

import os

num_epochs = 1000
patience = 3  # stop once more than `patience` consecutive epochs fail to improve val loss
early_stop = 0
min_loss = np.inf
print(save_best_path)
print('with final logit')
temperature = cfg['train_params']['temperature']
# Fixed (not learned) logit scale, equal to exp(log(1 / T)) = 1 / T. It is never
# registered with the optimizer, so compute it once. Building it per batch
# wasted work, and validation would otherwise depend on a value left over from
# the last training batch.
logit_scale = 1.0 / temperature
log = []

# Make sure the checkpoint directory exists before the first torch.save, so a
# long epoch is not lost to a missing folder.
os.makedirs(os.path.dirname(save_best_path) or '.', exist_ok=True)


def _contrastive_loss(outputs):
    """Return the loss between the source-modality and audio embeddings of one batch.

    Both embeddings are L2-normalised along dim 1 and combined into a
    batch-by-batch similarity matrix scaled by ``logit_scale``. That matrix is
    scored with ``cosine_similarity_loss`` (from ``utils``).
    """
    source_features = outputs['source_' + cfg['source_type']]
    target_features = outputs['audio']
    source_features = source_features / source_features.norm(dim=1, keepdim=True)
    target_features = target_features / target_features.norm(dim=1, keepdim=True)
    logits = logit_scale * source_features @ target_features.t()
    return cosine_similarity_loss(logits)


for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    total_loss = 0.0
    curr_lr = float(optimizer.param_groups[0]['lr'])

    for source, target in train_dataloader:
        outputs = model({'audio': target.to(device)}, source.to(device))
        loss = _contrastive_loss(outputs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    # ---- validation ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for source, target in val_dataloader:
            outputs = model({'audio': target.to(device)}, source.to(device))
            total_val_loss += _contrastive_loss(outputs).item()

    train_loss = total_loss / len(train_dataloader)
    val_loss = total_val_loss / len(val_dataloader)
    msg = (f"Epoch {epoch + 1}/{num_epochs}, Curr_LR: {curr_lr}, "
           f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(msg)
    log.append(msg)

    # Checkpoint on improvement; otherwise count towards early stopping.
    if val_loss < min_loss:
        min_loss = val_loss
        early_stop = 0
        torch.save(model, save_best_path)
        print('saved best')
    else:
        early_stop += 1

    if early_stop > patience:
        print("Early stopping triggered")
        break

    if cfg['train_params']['scheduler'] == 'cosw':
        scheduler.step()
    elif cfg['train_params']['scheduler'] == 'plateau':
        scheduler.step(val_loss)




In [None]:
import os

model_num = 2
log_path = f'./trained_model/log/{model_num}.txt'
# Create the log folder if needed so the write below does not fail.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
# One line per logged epoch, followed by the full config dict.
with open(log_path, 'w', encoding='utf-8') as f:
    for entry in log:
        f.write(str(entry) + '\n')

    f.write(str(cfg))

# test
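For zero-shot classification, download the BPE vocabulary file https://github.com/facebookresearch/ImageBind/blob/main/imagebind/bpe/bpe_simple_vocab_16e6.txt.gz and save it as `./ImageBind/imagebind/bpe/bpe_simple_vocab_16e6.txt.gz`, the `bpe_path` used by the cells below.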

In [None]:
import torch

# Load the fine-tuned model for evaluation. It is presumably a whole pickled
# nn.Module (saved with torch.save(model, ...)), which needs weights_only=False
# on PyTorch >= 2.6. Unpickling can run arbitrary code: trusted files only.
model = torch.load("./trained_model/reprogram_trained.pth", weights_only=False)
model.eval()

## audioset

In [None]:
from preprocessing.audioset import *


# Paths for the AudioSet evaluation split.
audioset_dir = './dataset/audio/audioset'
audio_dir = f'{audioset_dir}/test'
class_label_path = f'{audioset_dir}/class_labels_indices.csv'
csv_file = f'{audioset_dir}/eval_segments.csv'

# Clip paths, class names, then the per-clip targets for those clips.
audio_list = get_audio_pathes(audio_dir)
gt_classes = get_gt(class_label_path)
target_list = get_target(class_label_path, csv_file, audio_list)

In [None]:
from data import AudioSetDataset_audio
import torch

root_dir = "./dataset/audio/audioset/test"
test_dataset = AudioSetDataset_audio(
    root_dir,
    device,
    clips_per_video=5,
    audio_mean=-4.2677393,
    audio_std=4.56,
)

# Fixed order (shuffle=False) so embedding rows can be compared with
# target_list. This presumably matches get_audio_pathes' ordering — confirm.
loader_kwargs = dict(
    batch_size=cfg['batch_size'],
    shuffle=False,
    drop_last=False,
    pin_memory=False,
    sampler=None,
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Prompt templates for the zero-shot text embeddings (80 entries). '{}' is a
# placeholder, presumably filled with each class name inside
# zeroshot_classifier — confirm.
# NOTE(review): these are image-style prompts ("a photo of ...") applied to
# audio class names; confirm this is intended.
text_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]


In [None]:
bpe_path = './ImageBind/imagebind/bpe/bpe_simple_vocab_16e6.txt.gz'
# Per-class text embeddings built from text_templates. Move the model to the
# configured device instead of hard-coding CUDA, so a CPU-only run does not crash.
zeroshot_weights = zeroshot_classifier(model.to(device), gt_classes, text_templates, bpe_path, device)

In [None]:

all_emb = []
# Embed every test batch with gradient tracking disabled.
with torch.no_grad():
    for target in tqdm(test_dataloader):
        batch_features = model.forward({'audio': target.to(device)})
        all_emb.append(batch_features['audio'])

test_audio_features = torch.concat(all_emb)
# L2-normalise each row so the dot product with the class weights is a cosine similarity.
audio_features_norm = test_audio_features / test_audio_features.norm(dim=-1, keepdim=True)


In [None]:

# NOTE(review): repeats the concatenation and normalisation already done at the
# end of the embedding cell, from the same all_emb. Harmless but redundant.
test_audio_features = torch.concat(all_emb)
audio_features_norm = test_audio_features / test_audio_features.norm(dim=-1, keepdim=True)


In [None]:
# Clip-by-class similarity scores, scaled by 100. The `.T` assumes
# zeroshot_weights is (num_classes, embed_dim) — confirm against zeroshot_classifier.
logits = 100. * audio_features_norm @ zeroshot_weights.T
# Multi-label metric (presumably mean average precision), shown as a percentage.
map_multilabel_list(logits.cpu(), target_list)*100

## esc

In [None]:
# label setup
import sys
sys.path.append('./')

from preprocessing.esc50 import *

# ESC-50 metadata CSV and audio folder.
esc_root = './dataset/audio/esc50/ESC-50-master'
path = f'{esc_root}/meta/esc50.csv'
root_dir = f"{esc_root}/audio/"

# Class names, clip paths, and per-clip targets, all derived from the metadata CSV.
gt_classes = get_gt(path)
audio_list = get_audio_path(root_dir, path)
test_target = get_target(path)

In [None]:
from data import Esc50Dataset


# Clips are loaded on the CPU; the embedding loop moves each batch to `device`.
test_dataset = Esc50Dataset(
    root_dir,
    path,
    'cpu',
    clips_per_video=5,
    audio_mean=-4.2677393,
    audio_std=4.56,
)

loader_kwargs = dict(
    batch_size=2,
    shuffle=False,
    drop_last=False,
    pin_memory=False,
    sampler=None,
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [None]:
bpe_path = './ImageBind/imagebind/bpe/bpe_simple_vocab_16e6.txt.gz'
# Per-class text embeddings built from text_templates. Move the model to the
# configured device instead of hard-coding CUDA, so a CPU-only run does not crash.
zeroshot_weights = zeroshot_classifier(model.to(device), gt_classes, text_templates, bpe_path, device)

In [None]:


all_emb = []
# Embed every test batch with gradient tracking disabled.
with torch.no_grad():
    for target in tqdm(test_dataloader):
        batch_features = model.forward({'audio': target.to(device)})
        all_emb.append(batch_features['audio'])

test_audio_features = torch.concat(all_emb)
# L2-normalise each row so the dot product with the class weights is a cosine similarity.
audio_features_norm = test_audio_features / test_audio_features.norm(dim=-1, keepdim=True)


In [None]:
logits = 100. * audio_features_norm @ zeroshot_weights.T
# Put the targets on the same configured device as the logits, instead of a
# hard-coded .cuda(), so the cell also runs on CPU.
top1, top5 = top1_top5_acc(logits, test_target.to(device))
print(top1, top5)

# imagebind

In [None]:
# import numpy as np
import torch
from tqdm import tqdm
from pkg_resources import packaging

import sys
sys.path.append('./ImageBind/')

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

import os

In [None]:
# import model
# Prefer the first GPU when one is present.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
print(device)

# Pretrained ImageBind-Huge, switched to inference mode and moved to `device`.
model = imagebind_model.imagebind_huge(pretrained=True).eval()
model.to(device)


In [None]:
import gc

# Run Python's garbage collector (e.g. to reclaim objects left over from earlier cells).
gc.collect()


## audioset

In [None]:
from preprocessing.audioset import *


# Paths for the AudioSet evaluation split.
audioset_dir = './dataset/audio/audioset'
audio_dir = f'{audioset_dir}/test'
class_label_path = f'{audioset_dir}/class_labels_indices.csv'
csv_file = f'{audioset_dir}/eval_segments.csv'

# Clip paths, class names, then the per-clip targets for those clips.
audio_list = get_audio_pathes(audio_dir)
gt_classes = get_gt(class_label_path)
target_list = get_target(class_label_path, csv_file, audio_list)

In [None]:
# Prompt templates for the zero-shot text embeddings (80 entries). The
# zero-shot weight cells fill '{}' with each class name via
# template.format(classname).
# NOTE(review): these are image-style prompts ("a photo of ...") applied to
# audio class names; confirm this is intended.
text_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]


In [None]:
# Zero-shot classifier weights: one L2-normalised text embedding per class,
# averaged over all prompt templates. Each class vector is 1-D after the mean
# over prompts (assuming the text output is (num_prompts, embed_dim)).
# Stacking on dim=1 gives (embed_dim, num_classes), so
# `audio_emb @ zeroshot_weights` yields per-class logits.
zeroshot_weights = []
for classname in tqdm(list(gt_classes)):
    text_list = [template.format(classname) for template in text_templates]

    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    }

    with torch.no_grad():
        text_emb = model(inputs)[ModalityType.TEXT]
        # Normalise each prompt embedding, average over prompts, then re-normalise.
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        class_emb = text_emb.mean(dim=0)
        class_emb = class_emb / class_emb.norm()
    zeroshot_weights.append(class_emb)

# Keep the weights on the configured device (a hard-coded .cuda() fails on CPU).
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)

In [None]:


# Embed the evaluation clips in chunks of 10 files (presumably to bound peak memory).
chunk_size = 10
all_emb = []
for start in range(0, len(audio_list), chunk_size):
    chunk_paths = audio_list[start:start + chunk_size]
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(chunk_paths, device, mean=-4.2677393, std=4.56),
    }
    with torch.no_grad():
        imembeddings = model(inputs)
        all_emb.append(imembeddings[ModalityType.AUDIO])

audio_emb = torch.concat(all_emb)
# L2-normalise each clip embedding so `audio_emb @ zeroshot_weights` is a cosine similarity.
audio_emb /= audio_emb.norm(dim=-1, keepdim=True)

In [None]:
# Clip-by-class similarity scores, scaled by 100. zeroshot_weights is
# (embed_dim, num_classes) here, so no transpose is needed.
logits = 100. * audio_emb @ zeroshot_weights
# Multi-label metric (presumably mean average precision), shown as a percentage.
map_multilabel_list(logits.cpu(), target_list)*100

## esc

In [None]:
# label setup
import sys
sys.path.append('./')

from preprocessing.esc50 import *

# ESC-50 metadata CSV and audio folder.
esc_root = './dataset/audio/esc50/ESC-50-master'
path = f'{esc_root}/meta/esc50.csv'
root_dir = f"{esc_root}/audio/"

# Class names, clip paths, and per-clip targets, all derived from the metadata CSV.
gt_classes = get_gt(path)
audio_list = get_audio_path(root_dir, path)
test_target = get_target(path)

In [None]:
# Zero-shot classifier weights: one L2-normalised text embedding per class,
# averaged over all prompt templates. Each class vector is 1-D after the mean
# over prompts (assuming the text output is (num_prompts, embed_dim)).
# Stacking on dim=1 gives (embed_dim, num_classes), so
# `audio_emb @ zeroshot_weights` yields per-class logits.
zeroshot_weights = []
for classname in tqdm(list(gt_classes)):
    text_list = [template.format(classname) for template in text_templates]

    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    }

    with torch.no_grad():
        text_emb = model(inputs)[ModalityType.TEXT]
        # Normalise each prompt embedding, average over prompts, then re-normalise.
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        class_emb = text_emb.mean(dim=0)
        class_emb = class_emb / class_emb.norm()
    zeroshot_weights.append(class_emb)

# Keep the weights on the configured device (a hard-coded .cuda() fails on CPU).
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)

In [None]:


# Embed the ESC-50 clips in chunks of 10 files (presumably to bound peak memory).
chunk_size = 10
all_emb = []
for start in range(0, len(audio_list), chunk_size):
    chunk_paths = audio_list[start:start + chunk_size]
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(chunk_paths, device, mean=-4.2677393, std=4.56),
    }
    with torch.no_grad():
        imembeddings = model(inputs)
        all_emb.append(imembeddings[ModalityType.AUDIO])

audio_emb = torch.concat(all_emb)
# L2-normalise each clip embedding so `audio_emb @ zeroshot_weights` is a cosine similarity.
audio_emb /= audio_emb.norm(dim=-1, keepdim=True)

In [None]:
# zero shot prediction
def accuracy(output, target, topk=(1,)):
    """Count top-k hits of ``output`` scores against integer class targets.

    Args:
        output: Score matrix with one row per sample and one column per class.
        target: Ground-truth class index for each row of ``output``.
        topk: Tuple of k values to evaluate.

    Returns:
        A list with one float per k in ``topk``: the number of samples whose
        target is among the k highest-scoring classes. Divide by the sample
        count to get an accuracy.
    """
    largest_k = max(topk)
    # Shape (largest_k, num_samples): row j holds every sample's (j+1)-th best class.
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    counts = []
    for k in topk:
        top_k_hits = hits[:k].reshape(-1).float()
        counts.append(float(top_k_hits.sum(0, keepdim=True).cpu()))
    return counts

logits = 100. * audio_emb @ zeroshot_weights

# Use one sample count for both the hits and the denominator. A hard-coded
# target slice combined with the full target length gives a wrong accuracy
# whenever the two counts differ.
n = logits.size(0)
if test_target.size(0) != n:
    raise ValueError(
        f"expected one target per embedded clip: got {test_target.size(0)} targets for {n} clips"
    )

# measure accuracy
acc1, acc5 = accuracy(logits, test_target.to(device), topk=(1, 5))

top1 = (acc1 / n) * 100
top5 = (acc5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")