In [None]:
# reproduce results

In [None]:
import os
# Set before torch is imported so the process only sees GPU 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch, json, pandas as pd, numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm

device = torch.device('cuda')

# Test split shared by every system below; each sample is indexed with the keys
# 'premise', 'hypothesis', 'fig_type', 'label' and 'explanation'.
# A context manager closes the file handle instead of leaking it.
with open('./figurative_flute/data/test.json', 'r') as test_file:
    test_data = json.load(test_file)

# print(premise)
# print(hypothesis)
# print(fig)
# print(f'{label} | {pred_label}')
# print(f'{expl} | {pred_expl}')

#### System 1

In [None]:
# System 1: a single model produces both the label and the explanation.
tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System1_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    input_string = f"Premise: {premise} Hypothesis: {hypothesis} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys1_outputs.csv', header=True, index=False)

#### System 2

In [None]:
# System 2: the prompt additionally asks for the type of figurative language.
tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System2_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    input_string = f"Premise: {premise} Hypothesis: {hypothesis} "
    input_string += "What is the type of figurative language involved? Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # The label is taken as the text after the second ']' of the first
    # '.'-delimited segment; the explanation follows the first ':' of the
    # second segment (assumed output layout — TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(']')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 3 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[2].strip() if len(label_parts) > 2 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys2_outputs.csv', header=True, index=False)

#### System 3

In [None]:
def find_dream_scene(premise, hypothesis, scene_type, dream_data):
    """Look up the DREAM elaborations stored for a premise/hypothesis pair.

    Args:
        premise: Premise text to match exactly against ``entry['premise']``.
        hypothesis: Hypothesis text to match exactly against ``entry['hypothesis']``.
        scene_type: Suffix selecting the elaboration; the values are read from
            the keys ``premise_<scene_type>`` and ``hypothesis_<scene_type>``.
        dream_data: Iterable of dicts carrying the keys above.

    Returns:
        A ``(premise_scene, hypothesis_scene)`` pair taken from the first
        matching entry, or ``('', '')`` when no entry matches.
    """
    matching_entry = next(
        (
            entry
            for entry in dream_data
            if entry['premise'] == premise and entry['hypothesis'] == hypothesis
        ),
        None,
    )
    if matching_entry is None:
        return '', ''
    premise_key = f'premise_{scene_type}'
    hypothesis_key = f'hypothesis_{scene_type}'
    return matching_entry[premise_key], matching_entry[hypothesis_key]

##### emotion

In [None]:
# DREAM elaborations for the test split; closed deterministically via `with`.
with open('./figurative_flute/data/dream/test_dream.json', 'r') as dream_file:
    test_dream_data = json.load(dream_file)

tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System3_DREAM_FLUTE_emotion_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    # Both elaborations are '' when the pair is absent from test_dream_data.
    premise_dream, hypothesis_dream = find_dream_scene(premise, hypothesis, 'emotion', test_dream_data)
    input_string = f"Premise: {premise} [Premise - emotion] {premise_dream} Hypothesis: {hypothesis} [Hypothesis - emotion] {hypothesis_dream} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys3_emotion_outputs.csv', header=True, index=False)

##### motivation

In [None]:
# DREAM elaborations for the test split; closed deterministically via `with`.
with open('./figurative_flute/data/dream/test_dream.json', 'r') as dream_file:
    test_dream_data = json.load(dream_file)

tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System3_DREAM_FLUTE_motivation_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    # Both elaborations are '' when the pair is absent from test_dream_data.
    premise_dream, hypothesis_dream = find_dream_scene(premise, hypothesis, 'motivation', test_dream_data)
    input_string = f"Premise: {premise} [Premise - motivation] {premise_dream} Hypothesis: {hypothesis} [Hypothesis - motivation] {hypothesis_dream} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys3_motivation_outputs.csv', header=True, index=False)

##### consequence

In [None]:
# DREAM elaborations for the test split; closed deterministically via `with`.
with open('./figurative_flute/data/dream/test_dream.json', 'r') as dream_file:
    test_dream_data = json.load(dream_file)

tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System3_DREAM_FLUTE_consequence_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    # Both elaborations are '' when the pair is absent from test_dream_data.
    premise_dream, hypothesis_dream = find_dream_scene(premise, hypothesis, 'consequence', test_dream_data)
    input_string = f"Premise: {premise} [Premise - likely consequence] {premise_dream} Hypothesis: {hypothesis} [Hypothesis - likely consequence] {hypothesis_dream} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys3_consequence_outputs.csv', header=True, index=False)

##### social norm

In [None]:
# DREAM elaborations for the test split; closed deterministically via `with`.
with open('./figurative_flute/data/dream/test_dream.json', 'r') as dream_file:
    test_dream_data = json.load(dream_file)

tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System3_DREAM_FLUTE_social_norm_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    # The social-norm elaboration is stored under the 'rot' key suffix.
    # Both elaborations are '' when the pair is absent from test_dream_data.
    premise_dream, hypothesis_dream = find_dream_scene(premise, hypothesis, 'rot', test_dream_data)
    input_string = f"Premise: {premise} [Premise - social norm] {premise_dream} Hypothesis: {hypothesis} [Hypothesis - social norm] {hypothesis_dream} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys3_rot_outputs.csv', header=True, index=False)

##### all 4 dimensions

In [None]:
# DREAM elaborations for the test split; closed deterministically via `with`.
with open('./figurative_flute/data/dream/test_dream.json', 'r') as dream_file:
    test_dream_data = json.load(dream_file)

tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/System3_DREAM_FLUTE_all_dimensions_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types, tgt_labels, tgt_explanations, pred_labels, pred_explanations = [], [], [], [], [], [], []

for sample in tqdm(test_data):
    premise, hypothesis, fig = sample['premise'], sample['hypothesis'], sample['fig_type']
    label, expl = sample['label'], sample['explanation']
    # One lookup per dimension; each pair is ('', '') when the sample has no
    # entry in test_dream_data. The social-norm dimension uses the 'rot' suffix.
    premise_dream_emotion, hypothesis_dream_emotion = find_dream_scene(premise, hypothesis, 'emotion', test_dream_data)
    premise_dream_motivation, hypothesis_dream_motivation = find_dream_scene(premise, hypothesis, 'motivation', test_dream_data)
    premise_dream_consequence, hypothesis_dream_consequence = find_dream_scene(premise, hypothesis, 'consequence', test_dream_data)
    premise_dream_social, hypothesis_dream_social = find_dream_scene(premise, hypothesis, 'rot', test_dream_data)
    input_string = f"Premise: {premise} [Premise - social norm] {premise_dream_social} [Premise - emotion] {premise_dream_emotion} [Premise - motivation] {premise_dream_motivation} [Premise - likely consequence] {premise_dream_consequence} "
    input_string += f"Hypothesis: {hypothesis} [Hypothesis - social norm] {hypothesis_dream_social} [Hypothesis - emotion] {hypothesis_dream_emotion} [Hypothesis - motivation] {hypothesis_dream_motivation} [Hypothesis - likely consequence] {hypothesis_dream_consequence} "
    input_string += "Is there a contradiction or entailment between the premise and hypothesis?"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Parsing assumes "<tag>: <label>. <tag>: <explanation>" (TODO confirm against the model card).
    # NOTE: only the text up to the second '.' is kept, so a multi-sentence
    # explanation is truncated to its first sentence.
    segments = output.split('.')
    label_parts = segments[0].split(':')
    expl_parts = segments[1].split(':') if len(segments) > 1 else []
    if len(label_parts) < 2 or len(expl_parts) < 2:
        # A malformed generation used to raise IndexError and abort the whole
        # run; record an empty field instead and report the offending output.
        tqdm.write(f'Unparseable model output: {output!r}')
    pred_label = label_parts[1].strip() if len(label_parts) > 1 else ''
    pred_expl = expl_parts[1].strip() if len(expl_parts) > 1 else ''
    prems.append(premise)
    hypos.append(hypothesis)
    fig_types.append(fig)
    tgt_labels.append(label)
    tgt_explanations.append(expl)
    pred_labels.append(pred_label)
    pred_explanations.append(pred_expl)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys3_4dim_outputs.csv', header=True, index=False)

#### System 4

In [None]:
# System 4 is a two-stage pipeline: one model predicts the label, a second
# model writes the explanation conditioned on that predicted label.
tokenizer = AutoTokenizer.from_pretrained("t5-3b")
model_clf = AutoModelForSeq2SeqLM.from_pretrained("allenai/System4_classify_FigLang2022", torch_dtype=torch.float16).to(device)
model_exp = AutoModelForSeq2SeqLM.from_pretrained("allenai/System4_explain_FigLang2022", torch_dtype=torch.float16).to(device)

# Parallel per-sample lists; the next cell zips them into a DataFrame.
prems, hypos, fig_types = [], [], []
tgt_labels, tgt_explanations = [], []
pred_labels, pred_explanations = [], []

for sample in tqdm(test_data):
    premise = sample['premise']
    hypothesis = sample['hypothesis']
    fig = sample['fig_type']
    label = sample['label']
    expl = sample['explanation']

    # Stage 1: classification. The prompt ends right where the label is expected.
    clf_prompt = (
        f"Premise: {premise} Hypothesis: {hypothesis} "
        "Is there a contradiction or entailment between the premise and hypothesis? Answer : "
    )
    clf_ids = tokenizer.encode(clf_prompt, return_tensors="pt").to(device)
    clf_generation = model_clf.generate(clf_ids, max_length=200)
    pred_label = tokenizer.batch_decode(clf_generation, skip_special_tokens=True)[0].strip()

    # Stage 2: explanation. Same prompt, extended with the predicted label.
    exp_prompt = f"{clf_prompt}{pred_label}. Explanation : "
    exp_ids = tokenizer.encode(exp_prompt, return_tensors="pt").to(device)
    exp_generation = model_exp.generate(exp_ids, max_length=200)
    pred_expl = tokenizer.batch_decode(exp_generation, skip_special_tokens=True)[0].strip()

    for bucket, value in (
        (prems, premise),
        (hypos, hypothesis),
        (fig_types, fig),
        (tgt_labels, label),
        (tgt_explanations, expl),
        (pred_labels, pred_label),
        (pred_explanations, pred_expl),
    ):
        bucket.append(value)

In [None]:
# All per-sample lists must stay aligned before they are zipped into rows.
assert len(prems) == len(hypos) == len(fig_types) == len(pred_labels) == len(tgt_labels) == len(pred_explanations) == len(tgt_explanations)
print(len(prems))

cols = ['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation']
df = pd.DataFrame(list(zip(prems, hypos, fig_types, pred_labels, tgt_labels, pred_explanations, tgt_explanations)), columns=cols)

path = './figurative_flute/data/outputs/dream/'
# Create the output folder if needed so a missing directory cannot discard
# the predictions after the (expensive) inference loop has finished.
os.makedirs(path, exist_ok=True)
df.to_csv(f'{path}sys4_outputs.csv', header=True, index=False)

#### Ensemble

In [None]:
import os, pandas as pd

# Hard-coded (system, score) pairs used to pick the five voters.
# NOTE(review): presumably each system's Acc@0; where the numbers come from is
# not shown in this notebook — confirm before reusing them.
acc0_systems = [('sys1', 0.9465954606141522),('sys2', 0.9485981308411215),('sys3_emotion', 0.9392523364485982),('sys3_motivation', 0.9485981308411215),('sys3_consequence', 0.9459279038718291),('sys3_rot', 0.9238985313751669),('sys3_4dim', 0.9492656875834445),('sys4', 0.951935914552737)]
top5_acc0_systems = sorted(acc0_systems, key=lambda x: x[1], reverse=True)[:5]

dream_flute_outputs_path = './figurative_flute/data/outputs/dream/'
outputs_files = list(filter(lambda f: f.endswith('.csv'), os.listdir(dream_flute_outputs_path)))

# Explanations are taken from the first system in this order whose label
# agrees with the ensemble label.
ordered_systems = ['sys3_consequence', 'sys3_emotion', 'sys2', 'sys3_4dim', 'sys3_motivation', 'sys4', 'sys1']
ordered_systems_files = [f for system in ordered_systems for f in outputs_files if system in f]

# A system owns every CSV whose file name contains the system name.
voter_files = [(system, f) for system, _ in top5_acc0_systems for f in outputs_files if system in f]

# Read each needed CSV exactly once. The previous version re-read a file from
# disk for every (sample, system) pair inside the nested loop below.
needed_files = [outputs_files[0]] + [f for _, f in voter_files] + ordered_systems_files
frames = {f: pd.read_csv(dream_flute_outputs_path + f) for f in dict.fromkeys(needed_files)}

# Inputs and references are identical across systems, so any file provides them.
df_tmp = frames[outputs_files[0]]
premises, hypotheses, fig_types, labels, explanations = df_tmp['premise'].tolist(), df_tmp['hypothesis'].tolist(), df_tmp['type'].tolist(), df_tmp['ref_label'].tolist(), df_tmp['ref_explanation'].tolist()

df_labels = pd.DataFrame()

for system, file in voter_files:
    df_labels[system] = frames[file]['pred_label'].tolist()

# majority voting
df_labels['ensemble_label'] = df_labels.mode(axis=1)[0]
ensemble_labels = df_labels['ensemble_label'].tolist()

candidates = [
    (frames[f]['pred_label'].tolist(), frames[f]['pred_explanation'].tolist())
    for f in ordered_systems_files
]
ensemble_explanations = []

for sample_ix in range(len(df_tmp)):
    for cand_labels, cand_expls in candidates:
        if cand_labels[sample_ix] == ensemble_labels[sample_ix]:
            ensemble_explanations.append(cand_expls[sample_ix])
            break
    else:
        # No candidate agrees with the ensemble label. Append a placeholder so
        # the explanations stay aligned with the samples; previously nothing
        # was appended and zip() silently shifted every later row.
        ensemble_explanations.append('')

assert len(ensemble_explanations) == len(ensemble_labels)

df_out = pd.DataFrame(list(zip(premises, hypotheses, fig_types, ensemble_labels, labels, ensemble_explanations, explanations)), columns=['premise', 'hypothesis', 'type', 'pred_label', 'ref_label', 'pred_explanation', 'ref_explanation'])
df_out.to_csv(dream_flute_outputs_path + 'sys5_ensemble_outputs.csv', header=True, index=False)

#### DREAM

In [None]:
# Setup for the DREAM elaboration model, which generates the scene
# elaborations ('motivation', 'emotion', 'rot', 'consequence') in the next cell.
import os
# NOTE(review): an earlier cell of this notebook sets CUDA_VISIBLE_DEVICES to
# '0' before importing torch. If that cell already ran in the same kernel,
# this reassignment presumably has no effect on which GPU is used — run this
# section in a fresh kernel. TODO confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch, json
from tqdm import tqdm

device = torch.device('cuda')

# The model is loaded in half precision and placed on `device` via device_map.
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/DREAM", torch_dtype=torch.float16, device_map=device)
# The DREAM checkpoint is paired with the t5-11b tokenizer here.
tokenizer = AutoTokenizer.from_pretrained("t5-11b")

In [None]:
def _dream_elaborate(situation, dream_type):
    """Ask the DREAM model for one elaboration of `situation`.

    Args:
        situation: Sentence inserted after the ``[SITUATION]`` tag.
        dream_type: Query inserted after the ``[QUERY]`` tag
            ('motivation', 'emotion', 'rot' or 'consequence').

    Returns:
        The stripped answer text extracted from the generated string.

    Raises:
        IndexError: If the generated text contains neither expected delimiter.
    """
    input_string = f"$answer$ ; $question$ = [SITUATION] {situation} [QUERY] {dream_type}"
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)
    output = model.generate(input_ids, max_length=200)
    output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # Consequence answers may carry a 'consequence]' tag; when present the
    # answer is the text after it, otherwise the text after '$ ='.
    if dream_type == 'consequence' and 'consequence]' in output:
        return output.split('consequence]')[1].strip()
    return output.split('$ =')[1].strip()


def _add_dream_scenes(split, checkpoint_every):
    """Augment every sample of a data split with DREAM elaborations.

    Reads ``./figurative_flute/data/<split>.json``, adds the keys
    ``<premise|hypothesis>_<dream_type>`` to a copy of each sample, writes a
    checkpoint ``<split>_dream_<n>.json`` every `checkpoint_every` samples and
    the final ``<split>_dream.json``.

    Args:
        split: Split name ('train', 'val' or 'test').
        checkpoint_every: Number of samples between intermediate dumps.

    Returns:
        A ``(data, data_dream)`` pair: the loaded samples and their augmented copies.
    """
    data_dir = './figurative_flute/data'
    with open(f'{data_dir}/{split}.json', 'r') as json_file:
        data = json.load(json_file)

    data_dream = []
    for sample in tqdm(data):
        sample_dream = {**sample}
        for sent_type in ['premise', 'hypothesis']:
            for dream_type in ['motivation', 'emotion', 'rot', 'consequence']:
                sample_dream[f'{sent_type}_{dream_type}'] = _dream_elaborate(sample[sent_type], dream_type)
        data_dream.append(sample_dream)
        # Periodic checkpoint so an interrupted run keeps its partial results.
        if len(data_dream) % checkpoint_every == 0:
            with open(f'{data_dir}/{split}_dream_{len(data_dream)}.json', "w") as json_file:
                json.dump(data_dream, json_file, indent=4)
    assert len(data) == len(data_dream)

    with open(f'{data_dir}/{split}_dream.json', "w") as json_file:
        json.dump(data_dream, json_file, indent=4)
    return data, data_dream


# NOTE(review): the System 3 cells read
# './figurative_flute/data/dream/test_dream.json', whereas this cell writes
# './figurative_flute/data/test_dream.json'. The file presumably has to be
# moved into data/dream/ before running them — confirm.
train_data, train_data_dream = _add_dream_scenes('train', 500)
val_data, val_data_dream = _add_dream_scenes('val', 250)
test_data, test_data_dream = _add_dream_scenes('test', 250)
    
###############