# Import Libraries

In [1]:
from utils import get_config
from transformer_model import TransformerModel
from model import linear_model, lstm_model, resnet_model, ImageTransformer, ImageFeatureTransformer
import os
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Load the configuration

In [2]:
# Read the experiment settings from this run's YAML file.
# Point config_path at a different run's config.yaml as needed.
config_path = 'abaw/expr/vis_transformer_seq_4/config.yaml'
cfg = get_config(config_path)

# Load Model and Weights

In [3]:
# Instantiate the architecture named by cfg['model'].
# Any name not listed below falls through to ImageFeatureTransformer.
model_kind = cfg['model']
vit_checkpoint = 'google/vit-base-patch16-224-in21k'

if model_kind == 'transformer':
    # Sequence transformer, sized from the config.
    model = TransformerModel(
        seq_len=cfg['seq_size'],
        embedding_size=cfg['embedding_size'],
        nhead=cfg['n_head'],
        num_encoder_layers=cfg['n_layers'],
        num_classes=cfg['num_classes'],
        cfg=cfg,
    )
elif model_kind == 'linear':
    model = linear_model(cfg)
elif model_kind == 'lstm':
    model = lstm_model(cfg)
elif model_kind == 'resnet':
    model = resnet_model(cfg)
elif model_kind == 'vit':
    # ViT-based image model built from the pretrained checkpoint name.
    model = ImageTransformer(vit_checkpoint, cfg, device=cfg['device'])
else:
    # Default: feature-level transformer built from the same checkpoint name.
    model = ImageFeatureTransformer(vit_checkpoint, cfg, device=cfg['device'])

In [4]:
# Load the trained weights into the model.
# map_location='cpu' deserialises the checkpoint on the CPU, so loading also works
# on machines that lack the device the checkpoint was saved from; load_state_dict
# then copies the values into the model's own parameters, and the model is moved
# to cfg['device'] later, before inference.
checkpoint_path = 'abaw/expr/vis_transformer_seq_4/best_model_15_0.20817634738450513.pt'
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

<All keys matched successfully>

# Test Dataset

In [11]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import cv2

class MyDataset_test(Dataset):
    """Test-time dataset that serves image paths plus optional features/images.

    The list file at ``txt_path`` is read once: its first line is skipped as a
    header and the first comma-separated field of every other line is taken as
    an image path relative to ``data_root``. The first path component is used
    as the "directory" (video) identifier when building sequences.

    Args:
        data_root: Root folder holding the ``vis_feat``, ``aud_feat`` and
            ``imgs`` sub-folders.
        txt_path: Path to the list file described above.
        return_vis: Load ``vis_feat/<path>.npy``; entries whose visual feature
            file does not exist are dropped from the dataset.
        return_aud: Load ``aud_feat/<path>.npy``.
        return_img: Load ``imgs/<path>``, resized to 224x224 and converted to
            a tensor.
        return_seq: Return a window of ``seq_size`` frames ending at the
            requested index instead of a single frame.
        seq_size: Window length used when ``return_seq`` is true.
    """

    def __init__(self, data_root,
                 txt_path='/home/minseongjae/ABAW/0_data/test/test_set_examples/test_set_examples/CVPR_6th_ABAW_Expr_test_set_example.txt', 
                 return_vis=True, return_aud=False, return_img=False, return_seq=False, seq_size=10):
        self.data_root = data_root
        self.seq_size = seq_size
        self.return_seq = return_seq
        self.return_vis = return_vis
        self.return_aud = return_aud
        self.return_img = return_img

        with open(txt_path, 'r') as f:
            lines = f.readlines()
        # Strip whitespace/newlines so lines without a trailing comma still
        # yield a clean path, and ignore blank lines.
        img_paths = [line.split(',')[0].strip() for line in lines[1:]]
        img_paths = [p for p in img_paths if p]
        if return_vis:
            img_paths = [p for p in img_paths
                         if os.path.exists(self._feature_path('vis_feat', p))]
        self.img_paths = img_paths
        # Derived AFTER filtering so directories[i] always describes
        # img_paths[i]; otherwise sequence windows can mix different videos.
        self.directories = [p.split('/')[0] for p in img_paths]

        if return_img:
            # Resize expects a PIL image (or tensor), so the RGB ndarray coming
            # from OpenCV is converted first.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])

    def __len__(self):
        """Return the number of usable image paths."""
        return len(self.img_paths)

    def _feature_path(self, kind, img_path):
        """Return the ``.npy`` path for ``img_path`` inside sub-folder ``kind``."""
        return os.path.join(self.data_root, kind, img_path.replace('.jpg', '.npy'))

    def _load_feature(self, kind, img_path):
        """Load one feature array from sub-folder ``kind``."""
        return np.load(self._feature_path(kind, img_path))

    def _load_feature_seq(self, kind, img_paths_seq):
        """Load features for every path and concatenate them along axis 0."""
        stacked = np.concatenate(
            [self._load_feature(kind, p) for p in img_paths_seq], axis=0)
        return torch.tensor(stacked, dtype=torch.float)

    def _load_image(self, img_path):
        """Read ``imgs/<img_path>``, convert BGR to RGB and apply the transform.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        full_path = os.path.join(self.data_root, 'imgs', img_path)
        bgr = cv2.imread(full_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {full_path}')
        return self.transform(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    def _sequence_indices(self, idx):
        """Return ``seq_size`` indices ending at ``idx`` within one directory.

        Frames from other directories inside the window are skipped; if fewer
        than ``seq_size`` remain, the earliest kept index is repeated at the
        front to pad the window.
        """
        start_idx = max(idx - self.seq_size + 1, 0)
        current_dir = self.directories[idx]
        idxs = [i for i in range(start_idx, idx + 1)
                if self.directories[i] == current_dir]
        if len(idxs) < self.seq_size:
            idxs = [idxs[0]] * (self.seq_size - len(idxs)) + idxs
        return idxs

    def __getitem__(self, idx):
        """Return a dict for sample ``idx``.

        Sequence mode yields ``'img_paths'`` (list of paths) and single-frame
        mode yields ``'img_path'``; ``'vis_feat'``, ``'aud_feat'`` and ``'img'``
        are added according to the ``return_*`` flags.
        """
        data = {}
        if self.return_seq:
            img_paths_seq = [self.img_paths[i] for i in self._sequence_indices(idx)]
            data['img_paths'] = img_paths_seq
            if self.return_img:
                data['img'] = torch.stack([self._load_image(p) for p in img_paths_seq])
            if self.return_vis:
                data['vis_feat'] = self._load_feature_seq('vis_feat', img_paths_seq)
            if self.return_aud:
                data['aud_feat'] = self._load_feature_seq('aud_feat', img_paths_seq)
        else:
            img_path = self.img_paths[idx]
            data['img_path'] = img_path
            if self.return_vis:
                data['vis_feat'] = torch.tensor(
                    self._load_feature('vis_feat', img_path), dtype=torch.float)
            if self.return_aud:
                data['aud_feat'] = torch.tensor(
                    self._load_feature('aud_feat', img_path), dtype=torch.float)
            if self.return_img:
                data['img'] = self._load_image(img_path)
        return data

In [12]:
# Build the test dataset from the example test list; which modalities are
# loaded and the sequence settings all come from the config.
test_list_path = '/home/minseongjae/ABAW/0_data/test/test_set_examples/test_set_examples/CVPR_6th_ABAW_Expr_test_set_example.txt'
test_dataset = MyDataset_test(
    data_root=cfg['data_path'],
    txt_path=test_list_path,
    return_img=cfg['return_img'],
    return_aud=cfg['return_aud'],
    return_vis=cfg['return_vis'],
    return_seq=cfg['return_seq'],
    seq_size=cfg['seq_size'],
)

In [13]:
# Batch the test dataset in dataset order (no shuffling); batch size and worker
# count are taken from the config.
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg['batch_size'],
    shuffle=False,
    num_workers=cfg['num_workers'],
)

In [14]:
# Run the model over the whole test set and collect one predicted class index
# per sample together with the image path it belongs to.
img_paths = []
outputs = []
model.eval()
model.to(cfg['device'])
# Gradients are never needed for inference; no_grad avoids building autograd
# state for every batch.
with torch.no_grad():
    for data in tqdm(test_loader):
        vis_feat = data['vis_feat'].to(cfg['device'])
        output = model(vis_feat)
        if cfg['return_seq']:
            # In sequence mode the batch holds one list of paths per sequence
            # position; the last entry is the frame each prediction is recorded for.
            img_paths += data['img_paths'][-1]
        else:
            img_paths += data['img_path']
        # Predicted class = index of the largest score along axis 1.
        output = np.argmax(output.cpu().numpy(), axis=1)
        outputs += output.tolist()

  0%|          | 0/1969 [00:00<?, ?it/s]

100%|██████████| 1969/1969 [23:39<00:00,  1.39it/s]


In [16]:
# Sanity check: the number of predictions should equal the number of collected image paths.
len(outputs), len(img_paths)

(1007648, 1007648)

In [17]:
# Map each image path to its predicted class index. If a path occurs more than
# once, the later prediction overwrites the earlier one.
results = dict(zip(img_paths, outputs))

In [20]:
# Save the predictions as CSV: one "image_location,predicted_class" row per image.
import pandas as pd

# The header is written by hand. Passing the class names to to_csv as a single
# header alias makes pandas quote that field (it contains commas), which yields
# image_location,"Neutral,...,Other" instead of a plain comma-separated header.
header_line = 'image_location,Neutral,Anger,Disgust,Fear,Happiness,Sadness,Surprise,Other'
df = pd.DataFrame(list(results.items()))
with open('results.csv', 'w', newline='') as csv_file:
    csv_file.write(header_line + '\n')
    df.to_csv(csv_file, index=False, header=False)

In [21]:
# Go through results.csv line by line and keep only the rows whose
# image_location appears in the official test set list.
# The test set file is read once into a set, instead of being reopened and
# re-read for every result row, so each membership test is O(1).
# NOTE(review): this assumes the first comma-separated field of each line in
# the test set file is the image location (the same layout the dataset class
# parses) - confirm against the actual file.
with open('CVPR_6th_ABAW_Expr_test_set.txt', 'r') as t:
    test_locations = {line.split(',')[0].strip() for line in t}

with open('CVPR_6th_ABAW_Expr_test.txt', 'w') as f:
    with open('results.csv', 'r') as r:
        lines = r.readlines()
        # Copy the header row unchanged.
        f.write(lines[0])
        for line in lines[1:]:
            img_loc = line.split(',')[0]
            if img_loc in test_locations:
                f.write(line)

KeyboardInterrupt: 

# Missing Value