In [3]:
import pickle
import torch

from sklearn import metrics
from tqdm import tqdm
import sys
sys.path.append('..')

import corect

In [4]:
# Notebook-wide logger, obtained from corect's logging helper.
log = corect.utils.get_logger()

def load_pkl(file):
    """Return the object stored in the pickle file at path ``file``.

    The file is opened in binary mode and closed again before returning.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(file, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload

In [10]:
import os

# Show the kernel's working directory: the relative paths used further down
# (for example "models/corect_feat_iemocap.pkl") are resolved against it.
current_directory = os.getcwd()
print(current_directory)

/home/USER/aisafe_back


In [11]:
from sentence_transformers import SentenceTransformer
import corect


log = corect.utils.get_logger()
sbert_model = SentenceTransformer("paraphrase-distilroberta-base-v1")

corect.utils.set_seed(900)

# Load the pre-extracted features. The file is opened in a `with` block so the
# handle is closed as soon as unpickling finishes.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open("models/corect_feat_iemocap.pkl", "rb") as feat_file:
    (
        video_ids,
        video_speakers,
        video_labels,
        video_text,
        video_audio,
        video_visual,
        video_sentence,
        trainVids,
        test_vids,
    ) = pickle.load(feat_file, encoding="latin1")

# Hold out the first 10% of the training videos as the dev split; the
# remaining training videos form the train split.
dev_size = int(len(trainVids) * 0.1)
train_vids, dev_vids = trainVids[dev_size:], trainVids[:dev_size]


def _build_samples(vids, desc):
    """Return one record dict per video id in ``vids``.

    Each record carries the video's speakers, labels, audio and visual
    features, its raw sentences, and the SBERT encoding of those sentences
    under the ``"text"`` key.

    Args:
        vids: iterable of video ids used as keys into the ``video_*`` dicts.
        desc: label shown on the tqdm progress bar.

    Returns:
        A list of dicts, in the same order as ``vids``.
    """
    samples = []
    for vid in tqdm(vids, desc=desc):
        # presumably a list of utterance strings for this video -- TODO confirm
        sentences = video_sentence[vid]
        samples.append(
            {
                "vid": vid,
                "speakers": video_speakers[vid],
                "labels": video_labels[vid],
                "audio": video_audio[vid],
                "visual": video_visual[vid],
                "text": sbert_model.encode(sentences),
                "sentence": sentences,
            }
        )
    return samples


train = _build_samples(train_vids, "train")
dev = _build_samples(dev_vids, "dev")
test = _build_samples(test_vids, "test")

# NOTE(review): train/dev/test only live in kernel memory here, while the
# training cell further down reads "model/corect_data_iemocap.pkl" -- confirm
# where that file is produced.

log.info("train vids:")
log.info(sorted(train_vids))
log.info("dev vids:")
log.info(sorted(dev_vids))
log.info("test vids:")
log.info(sorted(test_vids))




Seed set 900


train: 100%|██████████| 108/108 [05:36<00:00,  3.11s/it]
dev: 100%|██████████| 12/12 [00:37<00:00,  3.10s/it]
test: 100%|██████████| 31/31 [02:11<00:00,  4.23s/it]

09/29/2024 11:49:43 train vids:
09/29/2024 11:49:43 ['Ses01F_impro01', 'Ses01F_impro02', 'Ses01F_impro03', 'Ses01F_impro04', 'Ses01F_impro05', 'Ses01F_impro06', 'Ses01F_impro07', 'Ses01F_script01_1', 'Ses01F_script01_2', 'Ses01F_script01_3', 'Ses01F_script02_1', 'Ses01F_script02_2', 'Ses01F_script03_1', 'Ses01F_script03_2', 'Ses01M_impro01', 'Ses01M_impro02', 'Ses01M_impro04', 'Ses01M_impro05', 'Ses01M_impro06', 'Ses01M_impro07', 'Ses01M_script01_1', 'Ses01M_script01_2', 'Ses01M_script01_3', 'Ses01M_script02_1', 'Ses01M_script03_1', 'Ses01M_script03_2', 'Ses02F_impro01', 'Ses02F_impro02', 'Ses02F_impro03', 'Ses02F_impro04', 'Ses02F_impro05', 'Ses02F_impro06', 'Ses02F_impro07', 'Ses02F_impro08', 'Ses02F_script01_2', 'Ses02F_script02_1', 'Ses02F_script02_2', 'Ses02F_script03_1', 'Ses02F_script03_2', 'Ses02M_impro01', 'Ses02M_impro03', 'Ses02M_impro04', 'Ses02M_impro05', 'Ses02M_impro06', 'Ses02M_script01_1', 'Ses02M_script01_2', 'Ses02M_script01_3', 'Ses02M_script02_1', 'Ses02M_script03_




In [None]:
from comet_ml import Experiment, Optimizer

import torch
import os
import corect

log = corect.utils.get_logger()

# NOTE(review): `args` is not defined in this cell. It must already exist in
# the kernel and provide data_root, device, learning_rate, max_grad_value,
# weight_decay, optimizer, scheduler and from_begin, plus whatever
# corect.Dataset / corect.CORECT / corect.Coach read from it.
# NOTE(review): this reads from "model/" while the feature pickle and the
# evaluation checkpoint live under "models/" -- confirm the directory name.
data = load_pkl("model/corect_data_iemocap.pkl")

trainset = corect.Dataset(data["train"], args)
devset = corect.Dataset(data["dev"], args)
testset = corect.Dataset(data["test"], args)

log.debug("Building model...")

# The statements below were indented without an enclosing block, which made
# the whole cell fail with IndentationError; they now run at cell level.
model_file = args.data_root + "/model_checkpoints/model.pt"
model = corect.CORECT(args).to(args.device)
opt = corect.Optim(args.learning_rate, args.max_grad_value, args.weight_decay)
opt.set_parameters(model.parameters(), args.optimizer)
sched = opt.get_scheduler(args.scheduler)

coach = corect.Coach(trainset, devset, testset, model, opt, sched, args)
if not args.from_begin:
    # Resume from the checkpoint previously written to `model_file`.
    ckpt = torch.load(model_file)
    coach.load_ckpt(ckpt)
    print("Training from checkpoint...")

# Train
log.info("Start training...")
ret = coach.train()

# Save. The first three items returned by coach.train() are stored under the
# keys below.
checkpoint = {
    "best_dev_f1": ret[0],
    "best_epoch": ret[1],
    "best_state": ret[2],
}

# Make sure the target directory exists so a finished training run is not
# lost to a missing-folder error at save time.
os.makedirs(os.path.dirname(model_file), exist_ok=True)
torch.save(checkpoint, model_file)


In [None]:
def main():
    """Evaluate a stored checkpoint on the test split and print the metrics.

    Loads the preprocessed splits from ``model/corect_data_iemocap.pkl`` and
    the checkpoint ``models/model_checkpoints/MELD_best_dev_f1_model_atv.pt``,
    runs the stored model on every test item under ``torch.no_grad()``, then
    prints sklearn's classification report and the weighted F1 score.

    Returns:
        The weighted F1 score computed by ``metrics.f1_score``.
    """
    # NOTE(review): the checkpoint name says "MELD" while the data pickle says
    # "iemocap" -- confirm that this pairing is intended.
    dataset_name = "MELD"
    modalities = "atv"

    data = load_pkl("model/corect_data_iemocap.pkl")
    # NOTE(review): the checkpoint holds non-tensor objects ("args"); recent
    # PyTorch releases default torch.load to weights_only=True, which may
    # reject such files -- confirm against the installed version.
    model_dict = torch.load(
        f"models/model_checkpoints/{dataset_name}_best_dev_f1_model_{modalities}.pt",
    )
    stored_args = model_dict["args"]
    # The object under "state_dict" is called like a model below, so it is
    # presumably a full module and not a plain state dict -- TODO confirm.
    model = model_dict["state_dict"]
    if isinstance(model, torch.nn.Module):
        # Switch off training-only behaviour (dropout, batch-norm updates) so
        # the reported metrics are deterministic.
        model.eval()
    testset = corect.Dataset(data["test"], stored_args)

    golds = []
    preds = []
    with torch.no_grad():
        for idx in tqdm(range(len(testset)), desc="test"):
            batch = testset[idx]
            # Keep a reference to the labels before the batch entries are
            # replaced by their on-device copies.
            golds.append(batch["label_tensor"])
            for key, value in batch.items():
                # "utterance_texts" is the only entry not moved to the device.
                if key != "utterance_texts":
                    batch[key] = value.to(stored_args.device)
            y_hat = model(batch)

            preds.append(y_hat.detach().to("cpu"))

    golds = torch.cat(golds, dim=-1).numpy()
    preds = torch.cat(preds, dim=-1).numpy()
    f1 = metrics.f1_score(golds, preds, average="weighted")

    print(metrics.classification_report(golds, preds, digits=4))
    print(f"F1 Score: {f1}")
    return f1


if __name__ == "__main__":
    # main() takes no parameters; the previous `main(args)` call raised a
    # TypeError (or a NameError when `args` was not defined).
    main()
