In [7]:
# Imports and training-question load (the tag vocabulary is built in a later cell)
from sklearn import metrics
import pandas as pd
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from sklearn import preprocessing
import numpy as np
from transformers import AutoTokenizer

input_train = "../../data/questions/Train_Questions54Top10.pkl"
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
train = pd.read_pickle(input_train)
# Fail fast if the columns QuestionDataset reads are missing, rather than with
# a KeyError deep inside a DataLoader worker later on.
assert {'title', 'desc_text', 'desc_code', 'clean_tags'}.issubset(train.columns), \
    f"unexpected training columns: {list(train.columns)}"
train.tail()

Unnamed: 0,qid,title,desc_text,desc_code,creation_date,clean_tags
154729,68290150,Decoding JWT token on app route using express-...,I am using Keycloak for authentication and usi...,UnauthorizedError: error:0909006C:PEM routines...,2021-07-07T16:52:52.147,['node.js']
154730,68101848,spring security 5.2 multiple WebSecurityConfig...,I have multiple WebSecurityConfigurerAdapter's...,@Configuration @EnableWebSecurity @ComponentSc...,2021-06-23T14:31:21.877,['java']
154731,69003799,replace &lt;/p&gt from the dataframe column,I have a Data Frame column description. I want...,"Description ""&lt;p&gt;ID being used for RPA te...",2021-08-31T18:34:44.080,['python']
154732,68541127,Display uploaded images in the website - Django,I have been trying to display my images in my ...,from django.contrib import admin # Register yo...,2021-07-27T08:07:24.323,['python']
154733,68665335,Thymeleaf templates does not resolve property,I need to use Thymeleaf templates for email se...,"<p th:text=""#{TEST}""></p> Context context = ne...",2021-08-05T11:01:18.817,['java']


In [2]:
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
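   File Name:     question.py
   Description :  PyTorch Dataset pairing tokenized question text (title +
                  description + code) with multi-hot tag targets.
   Author :        
   date:          2021/11/7 11:02 AM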
-------------------------------------------------
"""
from torch.utils.data import Dataset
import torch
import numpy as np
import ast


class QuestionDataset(Dataset):
    """Map-style dataset of questions encoded for multi-label tag classification.

    Each item joins a question's title, description text and description code
    with single spaces, tokenizes the result to a fixed length, and pairs it
    with the multi-hot tag vector produced by ``mlb``.

    Args:
        df: DataFrame with ``title``, ``desc_text``, ``desc_code`` and
            ``clean_tags`` columns. Rows are addressed by position, so any
            index (e.g. one left over from a split or filter) works.
        mlb: fitted binarizer with a ``transform(list_of_label_sets)`` method,
            e.g. sklearn's ``MultiLabelBinarizer``.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: length every example is padded/truncated to (default 512).
    """

    def __init__(self, df, mlb, tokenizer, max_length=512):
        # Reset to a 0..n-1 index so the integer passed to __getitem__ is a
        # row position; label lookup on the frame's own index would raise
        # KeyError (or return the wrong row) for non-default indexes.
        self.title = df['title'].reset_index(drop=True)
        self.text = df['desc_text'].reset_index(drop=True)
        self.code = df['desc_code'].reset_index(drop=True)
        self.targets = df['clean_tags'].reset_index(drop=True)
        self.tokenizer = tokenizer
        self.mlb = mlb
        self.max_length = max_length

    def __len__(self):
        return len(self.title)

    @staticmethod
    def _parse_tags(tags):
        """Return one row's tags as a set.

        Accepts a string repr of a list (e.g. ``"['java']"``) or an
        already-materialized iterable of tag strings.
        """
        if isinstance(tags, str):
            # literal_eval parses Python literals only -- no arbitrary code.
            return set(ast.literal_eval(tags))
        return set(tags)

    def __getitem__(self, index):
        """Encode row ``index`` (positional).

        Returns:
            dict with ``input_ids`` (long, shape ``(1, max_length)``),
            ``targets`` (float32 multi-hot vector) and, when the tokenizer
            provides one, ``attention_mask`` (long, same shape as input_ids).
        """
        title = str(self.title[index])
        text = str(self.text[index])
        code = str(self.code[index])
        tokens = title + " " + text + " " + code

        inputs = self.tokenizer(
            tokens,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = self._parse_tags(self.targets[index])
        multi_hot = self.mlb.transform([labels])[0]

        item = {
            # return_tensors='pt' already yields a tensor; wrapping it in
            # torch.tensor() made a redundant copy and raised a UserWarning.
            # The leading dim of 1 is kept for compatibility, so a DataLoader
            # batch is (batch, 1, max_length) -- squeeze dim 1 before a model.
            'input_ids': inputs['input_ids'].to(torch.long),
            'targets': torch.as_tensor(multi_hot, dtype=torch.float32),
        }
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            # padding='max_length' inserts pad tokens; the mask lets a model
            # ignore them.
            item['attention_mask'] = attention_mask.to(torch.long)
        return item


In [5]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")

tab_vocab_path = "../../data/tags/20211110/top10/common.csv"
tag_vocab = pd.read_csv(tab_vocab_path)
# read_csv parses blank cells and tag names such as "null" or "nan" as NaN,
# which astype(str) would silently turn into a bogus "nan" class. Fail loudly
# instead (re-read with keep_default_na=False if such tags are legitimate).
assert tag_vocab["tag"].notna().all(), f"missing/NaN-parsed tags in {tab_vocab_path}"
tag_list = tag_vocab["tag"].astype(str).tolist()
assert tag_list, f"empty tag vocabulary: {tab_vocab_path}"

# Fit on a single "sample" holding every tag so classes_ is the whole vocab.
mlb = preprocessing.MultiLabelBinarizer()
mlb.fit([tag_list])

training_set = QuestionDataset(train, mlb, tokenizer)

In [6]:
# Sanity check: encode the first question -- a dict of input_ids (shape (1, 512),
# note the leading dim of 1) and a multi-hot targets vector over mlb's classes.
training_set[0]



{'input_ids': tensor([[    0,  9064,  7257,   146,   304, 13360,  2557,   923, 23431,    38,
            524,   667,     7,   278,  2557,   194,   716,    15,   414,    31,
          26504,     4,   152,   851,   162,    41,  5849,    35, 30001, 14943,
             22,  9064,  7257,     4,  3698, 13360,   113,    16,   373,  1881,
           2368,     4, 30001, 14943,    29,   531,    28,   373,    11,     5,
           6089,   276,   645,    11,   358,  7681, 19930,     4,   407,    38,
           1276,     7,  2935,   127,  3260,    19,   304, 47210,    35,   635,
             38,  1017,  1195,    45,   278,     5,   194,  3225,     7,  4276,
            114,    38,   218,    75,   240,     7,     4, 28013,    64,    38,
           6136,    24,    30,   442,     5,  2557,   923, 23431,   116,  3945,
             89,   143,   526,  3038,     7,    42,   116,  2612,    74,    42,
            173,     8,    45,   127,  1461,  2120,   116, 10759, 42693,   877,
          48377, 23687, 494