In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from datasets import Dataset
import pandas as pd
import numpy as np
import random
import torch
import types
import os
import re

In [None]:
# Shell escape: print the NVIDIA driver / GPU status for this session.
!nvidia-smi

In [None]:
def get_folders_in_directory(directory_path):
    """Return the names of the immediate sub-directories of ``directory_path``.

    Entries are reported in the order produced by ``os.listdir``; anything
    that is not a directory is left out.
    """
    subdirectories = []
    for entry_name in os.listdir(directory_path):
        entry_path = os.path.join(directory_path, entry_name)
        if os.path.isdir(entry_path):
            subdirectories.append(entry_name)
    return subdirectories

# Working directory created for this session (presumably scratch space for
# intermediate files -- confirm against later cells).
# exist_ok=True keeps the cell re-runnable: without it, executing the cell a
# second time in the same session raises FileExistsError.
new_directory_path = "/kaggle/tmp"
os.makedirs(new_directory_path, exist_ok=True)

# List the folders directly under /kaggle to check the session layout.
directory_path = "/kaggle/"
folders = get_folders_in_directory(directory_path)

print(folders)

In [None]:
from IPython.display import display, HTML

# Render a <script> tag so the notebook front-end asks the kernel to execute a
# command that sets NotebookApp.iopub_msg_rate_limit to a very large value.
# NOTE(review): presumably intended to stop long outputs from being throttled;
# whether this setting can actually be changed this way at runtime is not
# evident from this notebook -- confirm it has an effect.
display(HTML("<script>Jupyter.notebook.kernel.execute('config NotebookApp.iopub_msg_rate_limit=10000000000')</script>"))

In [None]:
# ---------------------------------------------------------------------------
# Prompt templates, grouped per task (aste / tasd / asqp / acos).
#
# Placeholders filled in by get_step1_prompts / get_step2_prompts:
#   $T$              the input sentence
#   $Q$              "->" when the queried value occurs once among its
#                    element's values, "=>" when it occurs several times
#   $A$ $C$ $O$ $S$  the queried aspect / category / opinion / sentiment value
#
# Step-1 templates ask for all values of one element type.
# Step-2 templates ask for the remaining elements of a given value; the text
# after the colon names the order in which the elements are expected.
# IMPORTANT: the i-th template of a step-2 list is paired positionally with
# the i-th training target built in get_step2_prompts, so the order of the
# entries in each list must not be changed independently of that function.
# No template carries leading or trailing whitespace.
# ---------------------------------------------------------------------------

# --- ASTE: aspect / opinion / sentiment ------------------------------------
aste_step1_aspect = "$T$ => aspects: [A]"
aste_step1_opinion = "$T$ => opinions: [O]"
aste_step1_sentiment = "$T$ => sentiments: [S]"
aste_step2_aspect = [
    "$T$ $Q$ $A$: aspect, opinion, sentiment",
    "$T$ $Q$ $A$: aspect, sentiment, opinion",
]
aste_step2_opinion = [
    "$T$ $Q$ $O$: opinion, sentiment, aspect",
    "$T$ $Q$ $O$: opinion, aspect, sentiment",
]
aste_step2_sentiment = [
    "$T$ $Q$ $S$: sentiment, opinion, aspect",
    "$T$ $Q$ $S$: sentiment, aspect, opinion",
]

# --- TASD: aspect / category / sentiment -----------------------------------
tasd_step1_aspect = "$T$ => aspects: [A]"
tasd_step1_category = "$T$ => categories: [C]"
tasd_step1_sentiment = "$T$ => sentiments: [S]"
tasd_step2_aspect = [
    "$T$ $Q$ $A$: aspect, category, sentiment",
    "$T$ $Q$ $A$: aspect, sentiment, category",
]
tasd_step2_category = [
    "$T$ $Q$ $C$: category, sentiment, aspect",
    "$T$ $Q$ $C$: category, aspect, sentiment",
]
tasd_step2_sentiment = [
    "$T$ $Q$ $S$: sentiment, category, aspect",
    "$T$ $Q$ $S$: sentiment, aspect, category",
]

# --- ASQP: aspect / category / opinion / sentiment --------------------------
asqp_step1_aspect = "$T$ => aspects: [A]"
asqp_step1_category = "$T$ => categories: [C]"
asqp_step1_opinion = "$T$ => opinions: [O]"
asqp_step1_sentiment = "$T$ => sentiments: [S]"
asqp_step2_aspect = [
    "$T$ $Q$ $A$: aspect, category, opinion, sentiment",
    "$T$ $Q$ $A$: aspect, category, sentiment, opinion",
    "$T$ $Q$ $A$: aspect, opinion, sentiment, category",
    "$T$ $Q$ $A$: aspect, opinion, category, sentiment",
    "$T$ $Q$ $A$: aspect, sentiment, opinion, category",
    "$T$ $Q$ $A$: aspect, sentiment, category, opinion",
]
asqp_step2_category = [
    "$T$ $Q$ $C$: category, aspect, opinion, sentiment",
    "$T$ $Q$ $C$: category, aspect, sentiment, opinion",
    "$T$ $Q$ $C$: category, opinion, sentiment, aspect",
    "$T$ $Q$ $C$: category, opinion, aspect, sentiment",
    "$T$ $Q$ $C$: category, sentiment, opinion, aspect",
    "$T$ $Q$ $C$: category, sentiment, aspect, opinion",
]
asqp_step2_opinion = [
    "$T$ $Q$ $O$: opinion, category, aspect, sentiment",
    "$T$ $Q$ $O$: opinion, category, sentiment, aspect",
    "$T$ $Q$ $O$: opinion, aspect, sentiment, category",
    "$T$ $Q$ $O$: opinion, aspect, category, sentiment",
    "$T$ $Q$ $O$: opinion, sentiment, aspect, category",
    "$T$ $Q$ $O$: opinion, sentiment, category, aspect",
]
asqp_step2_sentiment = [
    "$T$ $Q$ $S$: sentiment, category, opinion, aspect",
    "$T$ $Q$ $S$: sentiment, category, aspect, opinion",
    "$T$ $Q$ $S$: sentiment, opinion, aspect, category",
    "$T$ $Q$ $S$: sentiment, opinion, category, aspect",
    "$T$ $Q$ $S$: sentiment, aspect, opinion, category",
    "$T$ $Q$ $S$: sentiment, aspect, category, opinion",
]

# --- ACOS: same four elements and exactly the same templates as ASQP -------
# The step-2 lists are copied so that each task owns an independent list.
acos_step1_aspect = asqp_step1_aspect
acos_step1_category = asqp_step1_category
acos_step1_opinion = asqp_step1_opinion
acos_step1_sentiment = asqp_step1_sentiment
acos_step2_aspect = list(asqp_step2_aspect)
acos_step2_category = list(asqp_step2_category)
acos_step2_opinion = list(asqp_step2_opinion)
acos_step2_sentiment = list(asqp_step2_sentiment)

# Lookup table: prompts[task]["step1" | "step2"][element] -> template(s).
prompts = {
    "aste": {
        "step1": {
            "aspect": aste_step1_aspect,
            "opinion": aste_step1_opinion,
            "sentiment": aste_step1_sentiment
        },
        "step2": {
            "aspect": aste_step2_aspect,
            "opinion": aste_step2_opinion,
            "sentiment": aste_step2_sentiment,
        }
    },
    "tasd": {
        "step1": {
            "aspect": tasd_step1_aspect,
            "category": tasd_step1_category,
            "sentiment": tasd_step1_sentiment
        },
        "step2": {
            "aspect": tasd_step2_aspect,
            "category": tasd_step2_category,
            "sentiment": tasd_step2_sentiment,
        }
    },
    "asqp": {
        "step1": {
            "aspect": asqp_step1_aspect,
            "category": asqp_step1_category,
            "opinion": asqp_step1_opinion,
            "sentiment": asqp_step1_sentiment
        },
        "step2": {
            "aspect": asqp_step2_aspect,
            "category": asqp_step2_category,
            "opinion": asqp_step2_opinion,
            "sentiment": asqp_step2_sentiment,
        }
    },
    "acos": {
        "step1": {
            "aspect": acos_step1_aspect,
            "category": acos_step1_category,
            "opinion": acos_step1_opinion,
            "sentiment": acos_step1_sentiment
        },
        "step2": {
            "aspect": acos_step2_aspect,
            "category": acos_step2_category,
            "opinion": acos_step2_opinion,
            "sentiment": acos_step2_sentiment,
        }
    }
}

In [None]:
# Element types handled by each task, in the order their prompts are generated.
element_key_list = {
    "aste": ["aspect", "opinion", "sentiment"],
    "tasd": ["aspect", "category", "sentiment"],
    "asqp": ["aspect", "category", "opinion", "sentiment"],
    "acos": ["aspect", "category", "opinion", "sentiment"],
}

def get_step1_prompts(task, text):
    """Build the step-1 prompts for ``text``, one per element type of ``task``.

    For every element listed in ``element_key_list[task]`` the matching
    template from ``prompts[task]["step1"]`` has its ``$T$`` placeholder
    replaced by ``text``. The prompts are returned in element order.
    """
    filled_prompts = []
    for element_name in element_key_list[task]:
        template = prompts[task]["step1"][element_name]
        filled_prompts.append(template.replace("$T$", text))
    return filled_prompts

# Placeholder letter used inside step-2 templates for each element type.
_STEP2_MARKS = {"aspect": "A", "category": "C", "opinion": "O", "sentiment": "S"}


def _step2_field_order(template):
    """Return the element names a step-2 template lists after its last ':'."""
    fields = [name.strip() for name in template.rsplit(":", 1)[-1].split(",")]
    unknown = [name for name in fields if name not in _STEP2_MARKS]
    if unknown:
        raise ValueError(
            f"step-2 template {template!r} lists unknown element(s): {unknown}"
        )
    return fields


def get_step2_prompts(status, task, text, aspects, categories=None, opinions=None, sentiments=None):
    """Build the step-2 prompts (and, for training, their targets) for one sentence.

    The element lists are parallel: position ``i`` of every list describes the
    same tuple. For each element type of ``task`` (see ``element_key_list``)
    and each distinct value of that type, in first-occurrence order, every
    template in ``prompts[task]["step2"][element]`` is filled in:

    * ``$A$`` / ``$C$`` / ``$O$`` / ``$S$`` -> the value itself,
    * ``$Q$`` -> ``"->"`` if the value occurs once in its list, ``"=>"`` if it
      occurs several times,
    * ``$T$`` -> ``text``.

    In ``"train"`` mode each prompt gets a target: for every position at which
    the value occurs, the elements named after the template's colon are joined
    with ``", "``; the per-position strings are joined with ``"; "``. Deriving
    the order from the template itself keeps prompt and target aligned.

    Args:
        status: ``"train"`` or ``"inference"``.
        task: key of ``element_key_list`` / ``prompts``.
        text: the sentence inserted for ``$T$``.
        aspects, categories, opinions, sentiments: value lists; only those
            belonging to ``task`` are read (the others may stay ``None``).

    Returns:
        ``(prompts, targets)`` for ``"train"``; the prompt list alone for
        ``"inference"``.

    Raises:
        ValueError: if ``status`` is neither ``"train"`` nor ``"inference"``,
            or a template names an unknown element.
    """
    if status not in ("train", "inference"):
        raise ValueError(f"status must be 'train' or 'inference', got {status!r}")
    build_targets = status == "train"

    # Explicit name -> list mapping (no lookup of local variables by name).
    values_by_element = {
        "aspect": aspects,
        "category": categories,
        "opinion": opinions,
        "sentiment": sentiments,
    }

    new_sent = []
    target = []
    for element in element_key_list[task]:
        values = values_by_element[element]
        templates = prompts[task]["step2"][element]
        mark = _STEP2_MARKS[element]
        field_orders = [_step2_field_order(t) for t in templates] if build_targets else []

        # Positions of every distinct value, keyed in first-occurrence order.
        positions = {}
        for idx, item in enumerate(values):
            positions.setdefault(item, []).append(idx)

        for item, idxs in positions.items():
            q = "->" if len(idxs) == 1 else "=>"
            for template_no, template in enumerate(templates):
                # Same substitution order as the placeholders are documented.
                filled = template.replace(f"${mark}$", item)
                filled = filled.replace("$Q$", q)
                filled = filled.replace("$T$", text)
                new_sent.append(filled)
                if build_targets:
                    fields = field_orders[template_no]
                    target.append("; ".join(
                        ", ".join(f"{values_by_element[name][i]}" for name in fields)
                        for i in idxs
                    ))

    if build_targets:
        return new_sent, target
    return new_sent

In [None]:
# Sentiment normalisation tables: short tags and full words both resolve to
# the full sentiment word.
senttag2opinion = dict(pos='positive', neg='negative', neu='neutral')
sentword2opinion = {word: word for word in ('positive', 'negative', 'neutral')}

# Category vocabularies for the restaurant datasets.
rest_aspect_general_cate_list = [
    'restaurant',
    'ambience',
    'location',
    'food',
    'service',
    'drinks',
]

rest_aspect_sub_cate_list = [
    'style_options',
    'quality',
    'prices',
    'miscellaneous',
    'general',
]

# Category vocabularies for the laptop datasets.
laptop_aspect_general_cate_list = [
    'display',
    'out_of_scope',
    'power_supply',
    'warranty',
    'keyboard',
    'os',
    'memory',
    'multimedia_devices',
    'laptop',
    'shipping',
    'graphics',
    'battery',
    'company',
    'ports',
    'motherboard',
    'optical_drives',
    'software',
    'fans&cooling',
    'hard_disc',
    'cpu',
    'hardware',
    'mouse',
    'support',
]

laptop_aspect_sub_cate_list = [
    'quality',
    'general',
    'price',
    'design_features',
    'miscellaneous',
    'portability',
    'connectivity',
    'usability',
    'operation_performance',
]

# Dataset name -> vocabulary. Every dataset of a domain shares the same list
# object.
_REST_DATASET_NAMES = ("rest14", "rest15", "rest", "rest16")
_LAPTOP_DATASET_NAMES = ("laptop", "laptop14")

general_cate_list = {
    **dict.fromkeys(_REST_DATASET_NAMES, rest_aspect_general_cate_list),
    **dict.fromkeys(_LAPTOP_DATASET_NAMES, laptop_aspect_general_cate_list),
}

sub_cate_list = {
    **dict.fromkeys(_REST_DATASET_NAMES, rest_aspect_sub_cate_list),
    **dict.fromkeys(_LAPTOP_DATASET_NAMES, laptop_aspect_sub_cate_list),
}

In [None]:
# Space-separated element markers of each task, in output order.
main_prompt_marks = dict(
    aste='[A] [O] [S]',
    tasd='[A] [C] [S]',
    asqp='[A] [C] [O] [S]',
    acos='[A] [C] [O] [S]',
)


def get_task_prompt_marks(task):
    """Return the marker string registered for ``task``.

    Raises ``KeyError`` when ``task`` has no entry in ``main_prompt_marks``.
    """
    marks_for_task = main_prompt_marks[task]
    return marks_for_task

def parse_aste_tuple(_tuple, sent):
    """Turn one ASTE label into ``[aspect_text, opinion_text, sentiment]``.

    A label whose first entry is already a string is returned unchanged (the
    same object). A label whose first entry is a list is treated as token
    positions into ``sent``: aspect and opinion are each rendered as the
    space-joined tokens from the first to the last listed position, inclusive.
    Any other label type is printed and rejected with ``NotImplementedError``.
    """
    head = _tuple[0]

    # Textual label: nothing to resolve.
    if isinstance(head, str):
        return _tuple

    if not isinstance(head, list):
        print(_tuple)
        raise NotImplementedError

    def span_text(token_positions):
        # Only the first and last positions matter; the span between them is
        # taken as contiguous.
        first = token_positions[0]
        last = first if len(token_positions) <= 1 else token_positions[-1]
        return ' '.join(sent[first:last + 1])

    return [span_text(head), span_text(_tuple[1]), _tuple[2]]

def get_task_tuple(_tuple, task):
    """Unpack a label tuple into ``(aspect, category, sentiment, opinion)``.

    Expected layout of ``_tuple`` per task:

    * ``"aste"``: ``(aspect, opinion, sentiment)`` -- category is ``None``
    * ``"tasd"``: ``(aspect, category, sentiment)`` -- opinion is ``None``
    * ``"asqp"`` / ``"acos"``: ``(aspect, category, sentiment, opinion)``

    A non-empty sentiment is normalised case-insensitively: full words are
    looked up in ``sentword2opinion``, anything else in ``senttag2opinion``.
    An aspect equal to ``"null"`` (any casing) is replaced by ``"it"``.

    Raises:
        NotImplementedError: for an unsupported ``task``.
        KeyError: if the sentiment is in neither lookup table.
    """
    if task == "aste":
        at, ot, sp = _tuple
        ac = None
    elif task == "tasd":
        at, ac, sp = _tuple
        ot = None
    elif task in ["asqp", "acos"]:
        at, ac, sp, ot = _tuple
    else:
        raise NotImplementedError(f"unsupported task: {task!r}")

    if sp:
        # Lower-case once so the membership test and the lookup use the same
        # key; testing the raw value would send e.g. "Positive" to the tag
        # table, where it is missing.
        sp_key = sp.lower()
        if sp_key in sentword2opinion:
            sp = sentword2opinion[sp_key]
        else:
            sp = senttag2opinion[sp_key]
    if at and at.lower() == 'null':
        at = 'it'

    return at, ac, sp, ot

def generate_element_list(_tuple, marks, task, final_gold_output):
    """Render the elements of one label tuple in the order given by ``marks``.

    ``marks`` is a space-separated string of markers (``[A]``, ``[C]``,
    ``[O]``, ``[S]``). When ``final_gold_output`` equals ``False`` each entry
    is just the element value as a string; otherwise each entry is
    ``"<marker> <value>"``.
    """
    at, ac, sp, ot = get_task_tuple(_tuple, task)
    value_for_mark = {"[A]": at, "[O]": ot, "[C]": ac, "[S]": sp}
    values_only = final_gold_output == False

    rendered = []
    for mark in marks.split(" "):
        value = value_for_mark[mark]
        if values_only:
            rendered.append(f"{value}")
        else:
            rendered.append(f"{mark} {value}")
    return rendered

In [None]:
# Keyword of get_step2_prompts that receives the values of each marker.
_MARK_TO_STEP2_KWARG = {
    "[A]": "aspects",
    "[C]": "categories",
    "[O]": "opinions",
    "[S]": "sentiments",
}


def _sorted_task_labels(sent_tokens, label, task):
    """Return the tuples of ``label`` ordered by their position in the sentence.

    ASTE labels are first converted with ``parse_aste_tuple``. Each tuple is
    keyed by the later of the first occurrences of its aspect and opinion
    text in the space-joined sentence; tuples for which neither is found
    sort last. Identical tuples are kept once.
    """
    if task == 'aste':
        assert len(label[0]) == 3
        label = [parse_aste_tuple(_tuple, sent_tokens) for _tuple in label]

    sentence = " ".join(sent_tokens)
    label_pos = {}
    for _tuple in label:
        at, ac, sp, ot = get_task_tuple(_tuple, task)
        at_pos = sentence.find(at) if at else -1
        ot_pos = sentence.find(ot) if ot else -1
        last_pos = max(at_pos, ot_pos)
        # 1e4 acts as "after everything" for tuples not located in the text.
        label_pos[tuple(_tuple)] = 1e4 if last_pos < 0 else last_pos
    return [list(k) for k, _ in sorted(label_pos.items(), key=lambda x: x[1])]


def get_para_targets(sents, labels, task, step, final_gold_output):
    """Build model inputs/targets, or gold strings, for tokenised sentences.

    Args:
        sents: list of token lists.
        labels: per-sentence list of label tuples (same length as ``sents``).
        task: ``"aste"``, ``"tasd"``, ``"asqp"`` or ``"acos"``.
        step: ``1`` or ``2``; only read when ``final_gold_output`` is False.
            Any other value contributes nothing to the outputs.
        final_gold_output: selects the output format, see below.

    Returns:
        If ``final_gold_output == False``: ``(new_sents, targets)``.
        For step 1 there is one prompt per element type (from
        ``get_step1_prompts``), its target being ``"<marker> v1; v2; ..."``.
        For step 2 the prompts and targets come from
        ``get_step2_prompts("train", ...)``.
        Otherwise: a list with one string per sentence, in which each tuple
        is rendered as ``"<marker> <value> ..."`` and tuples are joined by
        ``" [SSEP] "``.
    """
    marks = get_task_prompt_marks(task)

    if final_gold_output == False:
        mark_list = marks.split(" ")
        new_sents, targets = [], []
        for i in range(len(sents)):
            ordered = _sorted_task_labels(sents[i], labels[i], task)
            rows = [generate_element_list(_tuple, marks, task, False) for _tuple in ordered]
            main_sent = " ".join(sents[i])
            # Column j of ``rows`` holds the values of the j-th marker.
            columns = {
                mark: [row[j] for row in rows]
                for j, mark in enumerate(mark_list)
            }
            if step == 1:
                new_sents.extend(get_step1_prompts(task, main_sent))
                targets.extend(
                    f"{mark} " + "; ".join(columns[mark]) for mark in mark_list
                )
            elif step == 2:
                step2_kwargs = {
                    _MARK_TO_STEP2_KWARG[mark]: columns[mark] for mark in mark_list
                }
                new_sent, target = get_step2_prompts("train", task, main_sent, **step2_kwargs)
                targets.extend(target)
                new_sents.extend(new_sent)
        return new_sents, targets

    targets = []
    for i in range(len(sents)):
        ordered = _sorted_task_labels(sents[i], labels[i], task)
        targets.append(" [SSEP] ".join(
            " ".join(generate_element_list(_tuple, marks, task, True))
            for _tuple in ordered
        ))
    return targets

In [None]:
def read_line_examples_from_file(data_path, task):
    """Read an annotated data file and return its tokenized sentences and labels.

    Every non-blank line is stripped and lower-cased before it is parsed.

    * ``task != "unified"``: a line has the form ``sentence####labels``.
    * ``task == "unified"``: a line has the form
      ``task<TAB>data_name<TAB>sentence####labels``.

    The label part is a Python literal (a list of tuples/lists) and is parsed
    with ``ast.literal_eval``. Unlike the ``eval`` call it replaces, this only
    accepts literals, so a data file cannot execute arbitrary code.

    Args:
        data_path: path of the UTF-8 text file to read.
        task: task name; only the value ``"unified"`` changes the line format.

    Returns:
        ``(sents, labels)`` for single-task files, or
        ``(tasks, data_names, sents, labels)`` for ``"unified"`` files.
        ``sents`` is a list of token lists (whitespace split).

    Raises:
        ValueError: a single-task line does not contain exactly one ``####``
            separator, or the label part is not a valid literal.
        SyntaxError: the label part cannot be parsed at all.
        IndexError: a ``"unified"`` line is missing a field.
    """
    # Imported locally so the function carries its own dependency.
    import ast

    sents, labels = [], []
    with open(data_path, 'r', encoding='UTF-8') as fp:
        if task != "unified":
            for line in fp:
                line = line.strip().lower()
                if line != '':
                    words, tuples = line.split('####')
                    sents.append(words.split())
                    labels.append(ast.literal_eval(tuples.strip()))
            return sents, labels

        tasks, data_names = [], []
        for line in fp:
            line = line.strip().lower()
            if line != '':
                parts = line.split('\t')
                tasks.append(parts[0])
                data_names.append(parts[1])
                new_part = parts[2].split("####")
                sents.append(new_part[0].strip().split())
                labels.append(ast.literal_eval(new_part[1].strip()))
        return tasks, data_names, sents, labels


def get_transformed_io(data_path, task, step, final_gold_output=False, seed=None, percentage=None):
    """Load one data file and turn it into model inputs and/or targets.

    The file is read with ``read_line_examples_from_file`` and the result is
    handed to ``get_para_targets``. When both ``seed`` and ``percentage`` are
    given, the global ``random`` generator is re-seeded with ``seed`` and only
    ``int(len(sentences) * percentage)`` randomly chosen examples are kept.

    Returns ``(new_inputs, targets)`` when ``final_gold_output`` equals
    ``False``; otherwise returns whatever ``get_para_targets`` returns for
    the gold-output mode.
    """
    sents, labels = read_line_examples_from_file(data_path, task)
    inputs = [tokens.copy() for tokens in sents]

    subsample = seed is not None and percentage is not None
    if subsample:
        random.seed(seed)
        keep_count = int(len(inputs) * percentage)
        chosen = random.sample(list(range(0, len(inputs))), keep_count)
        # Apply the same index selection to sentences and labels so they stay aligned.
        inputs = [inputs[idx] for idx in chosen]
        labels = [labels[idx] for idx in chosen]

    if final_gold_output == False:
        new_inputs, targets = get_para_targets(inputs, labels, task, step, final_gold_output)
        return new_inputs, targets

    targets = get_para_targets(inputs, labels, task, step, final_gold_output)
    return targets

In [None]:
# Files are read from {parent_directory}{task}/{data_name}/{train,dev,test}.txt
parent_directory = "/kaggle/input/e2tp-absa/"
single_tasks = {
    "tasd": ["rest15", "rest16"],
    "asqp": ["rest15", "rest16"],
    "acos": ["laptop16", "rest16"],
    "aste": [
        "laptop14", "rest14", "rest15", "rest16",
        "laptop14-rest14", "laptop14-rest15", "laptop14-rest16",
        "rest14-laptop14", "rest15-laptop14", "rest16-laptop14",
    ],
}

# Step 1: every split is stored as [inputs, targets].
data_step1 = {}
for task, data_names in single_tasks.items():
    for data_name in data_names:
        split_dir = f"{parent_directory}{task}/{data_name}"
        train_inputs, train_targets = get_transformed_io(f"{split_dir}/train.txt", task, step=1, final_gold_output=False)
        dev_inputs, dev_targets = get_transformed_io(f"{split_dir}/dev.txt", task, step=1, final_gold_output=False)
        test_inputs, test_targets = get_transformed_io(f"{split_dir}/test.txt", task, step=1, final_gold_output=False)
        data_step1[f"{task}-{data_name}"] = [
            [train_inputs, train_targets],
            [dev_inputs, dev_targets],
            [test_inputs, test_targets],
        ]

# Step 2: train/dev are [inputs, targets]; the test entry only holds the gold targets.
data_step2 = {}
for task, data_names in single_tasks.items():
    for data_name in data_names:
        split_dir = f"{parent_directory}{task}/{data_name}"
        train_inputs, train_targets = get_transformed_io(f"{split_dir}/train.txt", task, step=2, final_gold_output=False)
        dev_inputs, dev_targets = get_transformed_io(f"{split_dir}/dev.txt", task, step=2, final_gold_output=False)
        test_targets = get_transformed_io(f"{split_dir}/test.txt", task, step=2, final_gold_output=True)
        data_step2[f"{task}-{data_name}"] = [
            [train_inputs, train_targets],
            [dev_inputs, dev_targets],
            [test_targets],
        ]

In [None]:
def extract_spans_para(seq, seq_type):
    """Parse a ``[SSEP]``-separated sequence into ``(ac, at, sp, ot)`` tuples.

    Each segment is expected to carry the markers ``[C]``, ``[S]``, ``[A]``
    and ``[O]`` in any order; the text after a marker (up to the next marker)
    is that marker's value. A marker that is absent is appended with the value
    ``null``. An ``[A]`` value equal to ``it`` (case-insensitive) is replaced
    by ``null``.

    ``seq_type`` is only used in the diagnostic message printed when a segment
    cannot be decoded; such a segment yields a tuple of four empty strings.
    """
    markers = ("[C]", "[S]", "[A]", "[O]")
    quads = []
    for segment in (piece.strip() for piece in seq.split('[SSEP]')):
        try:
            # Pad the segment so that all four markers are present.
            for marker in markers:
                if marker not in segment:
                    segment += " {} null".format(marker)

            positions = [segment.index(marker) for marker in markers]
            # by_position[k] is the index (into `markers`) of the k-th marker in the text.
            by_position = list(np.argsort(positions))

            values = []
            for marker_idx, position in enumerate(positions):
                begin = position + 4  # skip the 3-character marker and one more character
                rank = by_position.index(marker_idx)
                if rank < 3:
                    following = by_position[rank + 1]
                    chunk = segment[begin:positions[following]]
                else:
                    chunk = segment[begin:]
                values.append(chunk.strip())

            ac, sp, at, ot = values

            if at.lower() == 'it':
                at = 'null'
        except ValueError:
            try:
                print(f'In {seq_type} seq, cannot decode: {segment}')
            except UnicodeEncodeError:
                print(f'In {seq_type} seq, a string cannot be decoded')
            ac, at, sp, ot = '', '', '', ''

        quads.append((ac, at, sp, ot))

    return quads


def compute_f1_scores(pred_pt, gold_pt, verbose=True):
    """Compute micro precision, recall and F1 (as percentages) over tuple lists.

    ``pred_pt[i]`` and ``gold_pt[i]`` are the already-parsed tuples of sample
    ``i``. A predicted tuple counts as a hit when it is contained in the gold
    list of the same sample. ``verbose`` is accepted but not used.

    Returns a dict with the keys ``precision``, ``recall`` and ``f1``.
    """
    hits = total_gold = total_pred = 0

    for idx, predicted in enumerate(pred_pt):
        reference = gold_pt[idx]
        total_gold += len(reference)
        total_pred += len(predicted)
        hits += sum(1 for candidate in predicted if candidate in reference)

    # Guard the divisions: an empty prediction or gold set scores 0.
    if total_pred != 0:
        precision = float(hits) / float(total_pred)
    else:
        precision = 0
    if total_gold != 0:
        recall = float(hits) / float(total_gold)
    else:
        recall = 0
    if precision != 0 or recall != 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    return {
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100
    }


def compute_scores(pred_seqs, gold_seqs, verbose=True):
    """Parse predicted and gold sequences and score the predictions.

    Both sequence lists must have the same length. Every sequence is parsed
    with ``extract_spans_para`` and the parsed tuples are scored with
    ``compute_f1_scores``. ``verbose`` is accepted but not used.

    Returns ``(scores, all_labels, all_preds)``.
    """
    assert len(pred_seqs) == len(gold_seqs), (len(pred_seqs), len(gold_seqs))

    all_labels, all_preds = [], []
    # Gold and prediction of a sample are parsed back to back, sample by sample.
    for gold_seq, pred_seq in zip(gold_seqs, pred_seqs):
        all_labels.append(extract_spans_para(gold_seq, 'gold'))
        all_preds.append(extract_spans_para(pred_seq, 'pred'))

    scores = compute_f1_scores(all_preds, all_labels)
    return scores, all_labels, all_preds

In [None]:
def extract_text(input_string):
    """Return the part of a prompt that precedes the first ``=>``, stripped."""
    text_part, _arrow, _rest = input_string.partition("=>")
    return text_part.strip()

# Marker that stands for each element name used in a step-2 prompt.
_ELEMENT_MARKERS = {"aspect": "[A]", "category": "[C]", "opinion": "[O]", "sentiment": "[S]"}

# Three-element orders, tested in this sequence. Each one is searched with a
# leading ": " and must not be directly followed by a comma (which would make
# it the prefix of a four-element order).
_TRIPLET_ORDERS = (
    "aspect, opinion, sentiment",
    "aspect, sentiment, opinion",
    "opinion, sentiment, aspect",
    "opinion, aspect, sentiment",
    "sentiment, opinion, aspect",
    "sentiment, aspect, opinion",
    "aspect, category, sentiment",
    "aspect, sentiment, category",
    "category, sentiment, aspect",
    "category, aspect, sentiment",
    "sentiment, category, aspect",
    "sentiment, aspect, category",
)

# Four-element orders, tested in this sequence as plain substrings. Only the
# aspect-first orders are anchored with a leading ": ".
_QUAD_NEEDLES = (
    ": aspect, category, opinion, sentiment",
    ": aspect, category, sentiment, opinion",
    ": aspect, opinion, sentiment, category",
    ": aspect, opinion, category, sentiment",
    ": aspect, sentiment, opinion, category",
    ": aspect, sentiment, category, opinion",
    "category, aspect, opinion, sentiment",
    "category, aspect, sentiment, opinion",
    "category, opinion, sentiment, aspect",
    "category, opinion, aspect, sentiment",
    "category, sentiment, opinion, aspect",
    "category, sentiment, aspect, opinion",
    "opinion, category, aspect, sentiment",
    "opinion, category, sentiment, aspect",
    "opinion, aspect, sentiment, category",
    "opinion, aspect, category, sentiment",
    "opinion, sentiment, aspect, category",
    "opinion, sentiment, category, aspect",
    "sentiment, category, opinion, aspect",
    "sentiment, category, aspect, opinion",
    "sentiment, opinion, aspect, category",
    "sentiment, opinion, category, aspect",
    "sentiment, aspect, opinion, category",
    "sentiment, aspect, category, opinion",
)


def _markers_for_order(order):
    """Translate ``"aspect, opinion, sentiment"`` into ``["[A]", "[O]", "[S]"]``."""
    return [_ELEMENT_MARKERS[name] for name in order.split(", ")]


def find_current_order(input_prompt):
    """Return the element markers in the order requested by a step-2 prompt.

    The prompt is searched for an order phrase such as
    ``": aspect, opinion, sentiment"`` and the matching marker list (for that
    example ``["[A]", "[O]", "[S]"]``) is returned as a fresh list.

    A three-element phrase only matches when it is not immediately followed by
    a comma. Previously the check was "the last element name followed by a
    comma occurs nowhere in the prompt", which rejected valid prompts whose
    review text happened to contain e.g. ``"opinion,"``.

    Raises:
        ValueError: no known order phrase occurs in ``input_prompt``.
    """
    for order in _TRIPLET_ORDERS:
        if re.search(re.escape(": " + order) + r"(?!,)", input_prompt):
            return _markers_for_order(order)

    for needle in _QUAD_NEEDLES:
        if needle in input_prompt:
            order = needle[2:] if needle.startswith(": ") else needle
            return _markers_for_order(order)

    raise ValueError(f"Cannot determine the element order of prompt: {input_prompt!r}")

def extract_elements(text, char, remove_par=True):
    """Split ``text`` on ``char`` and return the non-blank, stripped pieces.

    When ``remove_par`` equals ``True`` every ``(`` and ``)`` is removed from
    ``text`` before it is split.
    """
    if remove_par == True:
        text = text.replace("(", "").replace(")", "")
    stripped_pieces = (piece.strip() for piece in text.split(char))
    return [piece for piece in stripped_pieces if piece != ""]

def set_default_order(current_order, goal_order, _tuple):
    """Rewrite generated tuples from ``current_order`` into ``goal_order``.

    ``_tuple`` is a string of one or more ``;``-separated groups, each a
    ``,``-separated list of element values (parentheses are ignored). The
    string counts as well formed when its total number of values equals
    ``len(current_order)`` times its number of groups.

    * Well formed: every value is prefixed with the marker at the same position
      in ``current_order``, the pairs of a group are sorted into ``goal_order``
      and joined with spaces; groups are joined with ``" [SSEP] "``.
    * Otherwise the string is split on ``"; "`` and each well-formed part is
      converted on its own. If at least one part was converted, the parts are
      joined with ``"; "`` again (unconverted parts are left as they were).
      If none was, ``_tuple`` is printed and ``""`` is returned.
    """
    def is_well_formed(candidate):
        value_count = len(extract_elements(candidate.replace(";", ","), ","))
        group_count = len(extract_elements(candidate, ";"))
        return value_count == len(current_order) * group_count

    def reorder_group(group):
        # Tag each value with its marker, then sort the pairs by the marker's rank in goal_order.
        tagged = [[current_order[pos], value] for pos, value in enumerate(extract_elements(group, ","))]
        rank = {marker: pos for pos, marker in enumerate(goal_order)}
        tagged.sort(key=lambda pair: rank[pair[0]])
        return " ".join(" ".join(pair) for pair in tagged)

    if is_well_formed(_tuple):
        if ";" in _tuple:
            groups = extract_elements(_tuple, ";", remove_par = False)
            return " [SSEP] ".join([reorder_group(group) for group in groups])
        return reorder_group(_tuple)

    parts = _tuple.split("; ")
    converted_any = False
    for pos, part in enumerate(parts):
        if is_well_formed(part):
            parts[pos] = set_default_order(current_order, goal_order, part)
            converted_any = True
    if converted_any:
        return "; ".join(parts)
    print(_tuple)
    return ""

def _generate_in_batches(model, tokenizer, device, prompts, batch_size, max_length, max_new_tokens):
    """Return one decoded generation per prompt, running ``batch_size`` prompts at a time."""
    decoded = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        encoded = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens)
        decoded.extend(tokenizer.decode(sequence, skip_special_tokens=True) for sequence in generated)
    return decoded


def inference(model_path_step1, model_path_step2, task_name, data_name, full_or_diet, depth):
    """Run the two-step pipeline on a test split and print its scores.

    1. The step-1 model generates an answer for every step-1 test prompt in
       ``data_step1``. The prompts of one sentence are consecutive (one prompt
       per element type), so the answers are consumed in groups.
    2. ``get_step2_prompts`` turns each group of answers into step-2 prompts,
       for which the step-2 model generates tuples.
    3. Every generated tuple is rewritten to the task's canonical element order
       with ``set_default_order`` and the tuples of a sentence are counted.
       Tuples generated more than ``threshold`` times are kept.
    4. The kept tuples are scored against the gold targets in ``data_step2``
       with ``compute_scores`` and the score dict is printed.

    Args:
        model_path_step1: path or hub id of the step-1 model and tokenizer.
        model_path_step2: path or hub id of the step-2 model and tokenizer.
        task_name: one of ``"aste"``, ``"tasd"``, ``"asqp"``, ``"acos"``.
        data_name: dataset name; ``f"{task_name}-{data_name}"`` must be a key
            of ``data_step1`` and ``data_step2``.
        full_or_diet: ``"diet"`` selects the lower vote threshold; any other
            value selects the full one.
        depth: how many times the threshold may be lowered by one for a
            sentence that still has no tuple.

    Returns:
        None. The scores are printed.

    Raises:
        ValueError: ``task_name`` is not a supported task.
    """
    # Per task: step-1 answers of one sentence as (keyword argument of
    # get_step2_prompts, marker prefix stripped from the answer), in prompt order.
    step1_layouts = {
        "aste": (("aspects", "[A] "), ("opinions", "[O] "), ("sentiments", "[S] ")),
        "tasd": (("aspects", "[A] "), ("categories", "[C] "), ("sentiments", "[S] ")),
        "asqp": (("aspects", "[A] "), ("categories", "[C] "), ("opinions", "[O] "), ("sentiments", "[S] ")),
    }
    step1_layouts["acos"] = step1_layouts["asqp"]
    # Per task: (canonical element order, full threshold, diet threshold).
    voting_rules = {
        "aste": (["[A]", "[O]", "[S]"], 3, 1),
        "tasd": (["[A]", "[C]", "[S]"], 3, 1),
        "asqp": (["[A]", "[C]", "[O]", "[S]"], 12, 2),
    }
    voting_rules["acos"] = voting_rules["asqp"]
    if task_name not in step1_layouts:
        raise ValueError(f"Unsupported task name: {task_name!r}")

    # Put the two models on separate GPUs when two are available. With a single
    # GPU both share cuda:0 (the previous hard-coded "cuda:1" failed there).
    if torch.cuda.is_available():
        device1 = "cuda:0"
        device2 = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
    else:
        device1 = device2 = "cpu"

    tokenizer_step1 = AutoTokenizer.from_pretrained(model_path_step1)
    model_step1 = AutoModelForSeq2SeqLM.from_pretrained(model_path_step1).to(device1)
    tokenizer_step2 = AutoTokenizer.from_pretrained(model_path_step2)
    model_step2 = AutoModelForSeq2SeqLM.from_pretrained(model_path_step2).to(device2)

    dataset_key = f"{task_name}-{data_name}"
    prompts_step1 = list(data_step1[dataset_key][2][0])

    # ---- Step 1: generate the candidate elements -------------------------
    torch.cuda.empty_cache()
    outputs_step1 = _generate_in_batches(
        model_step1, tokenizer_step1, device1, prompts_step1,
        batch_size=100, max_length=200, max_new_tokens=256,
    )
    texts = [extract_text(prompt) for prompt in prompts_step1]

    # ---- Build the step-2 prompts from each sentence's step-1 answers -----
    layout = step1_layouts[task_name]
    group_size = len(layout)
    prompts_step2 = []
    prompt_quantities_step2 = []
    for first in range(0, len(outputs_step1), group_size):
        candidates = {}
        for offset, (keyword, marker) in enumerate(layout):
            answer = outputs_step1[first + offset].replace(marker, "")
            candidates[keyword] = answer.split("; ")
        prompt_list = get_step2_prompts("inference", task_name, texts[first], **candidates)
        prompts_step2.extend(prompt_list)
        prompt_quantities_step2.append(len(prompt_list))

    # ---- Step 2: generate the tuples -------------------------------------
    torch.cuda.empty_cache()
    outputs_step2 = _generate_in_batches(
        model_step2, tokenizer_step2, device2, prompts_step2,
        batch_size=200, max_length=250, max_new_tokens=512,
    )
    # Wrap every "; "-separated tuple in parentheses.
    outputs_step2 = [
        "; ".join("(" + piece + ")" for piece in output.split("; "))
        for output in outputs_step2
    ]

    goal_order, full_threshold, diet_threshold = voting_rules[task_name]
    threshold = diet_threshold if full_or_diet == "diet" else full_threshold

    # ---- Vote over the tuples produced for each sentence -----------------
    tuples = []
    position = 0
    for quantity in prompt_quantities_step2:
        votes = {}
        for offset in range(position, position + quantity):
            current_order = find_current_order(prompts_step2[offset])
            item = set_default_order(current_order, goal_order, outputs_step2[offset])
            if item.find("[SSEP]") != -1:
                voted = item.split(" [SSEP] ")
            elif item.strip() != "":
                voted = [item]
            else:
                voted = []
            for candidate in voted:
                votes[candidate] = votes.get(candidate, 0) + 1
        position += quantity

        kept = " [SSEP] ".join([key for key, value in votes.items() if value > threshold])
        # Lower the threshold step by step while the sentence has no tuple.
        for relaxation in range(1, depth + 1):
            if kept != "":
                break
            kept = " [SSEP] ".join([key for key, value in votes.items() if value > threshold - relaxation])
        tuples.append(kept)

    golds = data_step2[dataset_key][2][0]
    scores, all_labels, all_preds = compute_scores(tuples, golds)
    print(scores)

In [None]:
# getpass reads the Hub token without echoing it into the notebook output.
from getpass import getpass

# Interactive configuration: "train" collects the training settings used by the
# cells below; "inference" collects the model paths and runs inference directly.
status = input("Choose the Status: ")
if status == "train":
    model_path = input("Model Path: ")
    train_task_name = input("Training Task Name: ")
    train_data_name = input("Training Data Name: ")
    step = int(input("Step: "))
    if step not in (1, 2):
        # Any other value would leave the settings below undefined.
        raise ValueError(f"Step must be 1 or 2, got {step}")
    model_seed = int(input("Trainer Seed (0 or 42): "))
    if step == 1:
        full_or_diet_or_low = "full"
    elif step == 2:
        full_or_diet_or_low = input("Full Data or Diet or Low Resource: ")
        if full_or_diet_or_low == "diet":
            random.seed(model_seed)

            # Keep a single, randomly chosen step-2 prompt template per element.
            for element in element_key_list[train_task_name]:
                prompts[train_task_name]['step2'][element] = [random.choice(prompts[train_task_name]['step2'][element])]

            train_inputs, train_targets = get_transformed_io(f"{parent_directory}{train_task_name}/{train_data_name}/train.txt", train_task_name, step=2, final_gold_output=False)
            dev_inputs, dev_targets = get_transformed_io(f"{parent_directory}{train_task_name}/{train_data_name}/dev.txt", train_task_name, step=2, final_gold_output=False)
            test_targets = get_transformed_io(f"{parent_directory}{train_task_name}/{train_data_name}/test.txt", train_task_name, step=2, final_gold_output=True)
            # Take the target of every remaining input from the data built with all templates.
            train_targets = [data_step2[f"{train_task_name}-{train_data_name}"][0][1][data_step2[f"{train_task_name}-{train_data_name}"][0][0].index(train_input)] for train_input in train_inputs]
            dev_targets = [data_step2[f"{train_task_name}-{train_data_name}"][1][1][data_step2[f"{train_task_name}-{train_data_name}"][1][0].index(dev_input)] for dev_input in dev_inputs]
            data_step2[f"{train_task_name}-{train_data_name}"] = [[train_inputs, train_targets], [dev_inputs, dev_targets], [test_targets]]

        elif full_or_diet_or_low == "low":
            percentage = int(input("Percentage (1 or 2 or 5 or 10): "))
            # The seed must be an int: it also becomes the trainer seed below,
            # and a string would seed `random` differently from the number.
            data_seed = int(input("Data Seed (5 or 10 or 20 or 25): "))
            model_seed = data_seed
    epoch = int(input("Epoch: "))
    huggingface_token = getpass("Huggingface Token: ")
    repo_name = input("Repository Name: ")
elif status == "inference":
    model_path_step1 = input("First Step Model Path in Huggingface: ")
    model_path_step2 = input("Second Step Model Path in Huggingface: ")
    inference_task_name = input("Inference Task Name: ")
    inference_data_name = input("Inference Data Name: ")
    full_or_diet = input("Full Data or Diet: ")
    depth = int(input("Depth for Empty Tuples: "))
    if full_or_diet == "diet":
        model_seed = int(input("Trainer Seed (0 or 42): "))
        random.seed(model_seed)

        # Keep a single, randomly chosen step-2 prompt template per element.
        for element in element_key_list[inference_task_name]:
            prompts[inference_task_name]['step2'][element] = [random.choice(prompts[inference_task_name]['step2'][element])]

        train_inputs, train_targets = get_transformed_io(f"{parent_directory}{inference_task_name}/{inference_data_name}/train.txt", inference_task_name, step=2, final_gold_output=False)
        dev_inputs, dev_targets = get_transformed_io(f"{parent_directory}{inference_task_name}/{inference_data_name}/dev.txt", inference_task_name, step=2, final_gold_output=False)
        test_targets = get_transformed_io(f"{parent_directory}{inference_task_name}/{inference_data_name}/test.txt", inference_task_name, step=2, final_gold_output=True)
        data_step2[f"{inference_task_name}-{inference_data_name}"] = [[train_inputs, train_targets], [dev_inputs, dev_targets], [test_targets]]

    inference(model_path_step1, model_path_step2, inference_task_name, inference_data_name, full_or_diet, depth)

In [None]:
# Low-resource mode: rebuild both data dicts so that they only hold the chosen
# dataset, with train and dev subsampled to `percentage` percent using `data_seed`.
if full_or_diet_or_low == "low":
    random.seed(data_seed)

    low_dir = f"{parent_directory}{train_task_name}/{train_data_name}"
    low_key = f"{train_task_name}-{train_data_name}"
    fraction = percentage/100

    data_step1 = {}
    train_inputs, train_targets = get_transformed_io(f"{low_dir}/train.txt", train_task_name, step=1, final_gold_output=False, seed=data_seed, percentage=fraction)
    dev_inputs, dev_targets = get_transformed_io(f"{low_dir}/dev.txt", train_task_name, step=1, final_gold_output=False, seed=data_seed, percentage=fraction)
    # The test split is never subsampled.
    test_inputs, test_targets = get_transformed_io(f"{low_dir}/test.txt", train_task_name, step=1, final_gold_output=False)
    data_step1[low_key] = [
        [train_inputs, train_targets],
        [dev_inputs, dev_targets],
        [test_inputs, test_targets],
    ]

    data_step2 = {}
    train_inputs, train_targets = get_transformed_io(f"{low_dir}/train.txt", train_task_name, step=2, final_gold_output=False, seed=data_seed, percentage=fraction)
    dev_inputs, dev_targets = get_transformed_io(f"{low_dir}/dev.txt", train_task_name, step=2, final_gold_output=False, seed=data_seed, percentage=fraction)
    test_targets = get_transformed_io(f"{low_dir}/test.txt", train_task_name, step=2, final_gold_output=True)
    data_step2[low_key] = [
        [train_inputs, train_targets],
        [dev_inputs, dev_targets],
        [test_targets],
    ]

In [None]:
# Load the tokenizer and model to fine-tune, on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)

# [train, dev, ...] splits of the selected step and dataset; each is [inputs, targets].
whole_related_data = globals()[f"data_step{step}"][f"{train_task_name}-{train_data_name}"]
# Longest tokenized input / target over train + dev, used as padding lengths later.
input_max_length = max(
    len(tokenizer.tokenize(text))
    for text in (whole_related_data[0][0] + whole_related_data[1][0])
)
target_max_length = max(
    len(tokenizer.tokenize(text))
    for text in (whole_related_data[0][1] + whole_related_data[1][1])
)

def init_args(epoch):
    """Return the ``(step 1, step 2)`` training hyper-parameter namespaces.

    Both namespaces share the batch sizes, the output directory and
    ``num_train_epochs=epoch``; they differ only in the learning rate.
    """
    shared = dict(
        train_batch_size=16,
        eval_batch_size=8,
        output_dir='/kaggle/tmp',
        num_train_epochs=epoch,  # Default (full): 15 for step 1, 20 for step 2
    )
    args_step1 = types.SimpleNamespace(learning_rate=3e-4, **shared)
    # For ASTE (R16) and TASD (R16) full, learning_rate = 2e-4, epoch = 15
    args_step2 = types.SimpleNamespace(learning_rate=1e-4, **shared)
    return args_step1, args_step2

# Build both hyper-parameter sets and keep the one for the step being trained.
args_step1, args_step2 = init_args(epoch)
if step == 2:
    args = args_step2
elif step == 1:
    args = args_step1
# Show the padding lengths computed above, one per line.
print(input_max_length, target_max_length, sep="\n")

In [None]:
# Wrap the train and dev splits of the selected step/dataset as Hugging Face Datasets.
_selected_splits = globals()[f"data_step{step}"][f"{train_task_name}-{train_data_name}"]
_train_split = _selected_splits[0]
_val_split = _selected_splits[1]

train_df = pd.DataFrame({"input": _train_split[0], "target": _train_split[1]})
val_df = pd.DataFrame({"input": _val_split[0], "target": _val_split[1]})
train_data = Dataset.from_dict({column: train_df[column] for column in ("input", "target")})
val_data = Dataset.from_dict({column: val_df[column] for column in ("input", "target")})

In [None]:
def convert_examples_to_features(example_batch):
    """Tokenize a batch of ``input``/``target`` texts into model features.

    Inputs are padded/truncated to ``input_max_length`` and targets to
    ``target_max_length`` (module-level values computed from the data).

    Padding positions in the labels are replaced by ``-100`` so that the loss
    ignores them. Previously the labels kept the pad token id, so the loss was
    also computed on the padding that fills every target up to
    ``target_max_length``.

    Args:
        example_batch: mapping with the keys ``"input"`` and ``"target"``,
            each a list of strings.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    input_encodings = tokenizer(
        example_batch["input"], padding="max_length", truncation=True, max_length=input_max_length
    )
    target_encodings = tokenizer(
        example_batch["target"], padding="max_length", truncation=True, max_length=target_max_length
    )

    pad_token_id = tokenizer.pad_token_id
    labels = [
        [token_id if token_id != pad_token_id else -100 for token_id in sequence]
        for sequence in target_encodings['input_ids']
    ]

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels
    }

# Tokenize both splits.
train_pt = train_data.map(convert_examples_to_features, batched=True)
val_pt = val_data.map(convert_examples_to_features, batched=True)

seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer_args = TrainingArguments(
    output_dir = args.output_dir,
    num_train_epochs = args.num_train_epochs,
    learning_rate = args.learning_rate,
    per_device_train_batch_size = args.train_batch_size,
    per_device_eval_batch_size = args.eval_batch_size,
    evaluation_strategy = "epoch",
    logging_strategy = "epoch",
    # Mixed precision needs a CUDA device. The setup cell falls back to the CPU
    # when none is available, so fp16 must be switched off in that case too.
    fp16 = torch.cuda.is_available(),
    seed = model_seed
)

trainer = Trainer(
    model=model,
    args=trainer_args,
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    train_dataset=train_pt,
    eval_dataset=val_pt
)

torch.cuda.empty_cache()

trainer.train()

# Persist the fine-tuned model and its tokenizer side by side.
trainer.save_model("/kaggle/working/")

tokenizer.save_pretrained("/kaggle/working/")

In [None]:
# Log in to the Hugging Face Hub with the token entered earlier and create the model repo.
# NOTE(review): the token is interpolated into a shell command line, so it may show up
# in the cell output or process list — confirm this is acceptable before sharing the notebook.
!huggingface-cli login --token {huggingface_token}
!huggingface-cli repo create {repo_name} --type model -y
# Reload the model that was saved to /kaggle/working/ and upload it together with the tokenizer.
finetuned_model = AutoModelForSeq2SeqLM.from_pretrained('/kaggle/working/')
finetuned_model.push_to_hub(repo_name)
tokenizer.push_to_hub(repo_name)