In [1]:
import os
import re
import time
import math
import torch
import random
import pickle
import os.path
import torchtext
import matplotlib
import numpy as np
import pandas as pd
import torch.nn as nn
import sentencepiece as spm


from glob import glob
from torch import optim
from itertools import cycle
from torchtext import vocab
from konlpy.tag import Mecab
from torchinfo import summary
from torch.nn import Transformer
from torch.utils.data import DataLoader
#from torchaudio.models import Conformer
from timeit import default_timer as timer
from soyspacing.countbase import CountSpace
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import word_error_rate
from nltk.translate.bleu_score import sentence_bleu
from torchtext.vocab import build_vocab_from_iterator

In [2]:
from typing import Optional, Tuple


__all__ = ["Conformer"]




class _ConvolutionModule(torch.nn.Module):
    r"""Convolution sub-block used inside a Conformer layer.

    The input is layer-normalised and then passed, channels-first, through:
    pointwise conv (to ``2 * num_channels``) -> GLU -> depthwise conv ->
    GroupNorm or BatchNorm1d -> SiLU -> pointwise conv (back to ``input_dim``)
    -> Dropout.

    Args:
        input_dim (int): feature dimension of the input and of the output.
        num_channels (int): channel count seen by the depthwise convolution.
        depthwise_kernel_size (int): kernel size of the depthwise convolution; must be odd.
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): whether every convolution carries a bias term. (Default: ``False``)
        use_group_norm (bool, optional): normalise with ``GroupNorm`` (one group) instead of
            ``BatchNorm1d``. (Default: ``False``)

    Raises:
        ValueError: if ``depthwise_kernel_size`` is not odd.
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        # Only an odd kernel can be padded symmetrically so that T is preserved.
        if depthwise_kernel_size % 2 != 1:
            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")

        self.layer_norm = torch.nn.LayerNorm(input_dim)

        # The expansion doubles the channels because GLU halves them again.
        expand = torch.nn.Conv1d(input_dim, 2 * num_channels, 1, stride=1, padding=0, bias=bias)
        depthwise = torch.nn.Conv1d(
            num_channels,
            num_channels,
            depthwise_kernel_size,
            stride=1,
            padding=(depthwise_kernel_size - 1) // 2,
            groups=num_channels,
            bias=bias,
        )
        if use_group_norm:
            norm = torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
        else:
            norm = torch.nn.BatchNorm1d(num_channels)
        project = torch.nn.Conv1d(num_channels, input_dim, kernel_size=1, stride=1, padding=0, bias=bias)

        # The position of each stage inside the Sequential determines its
        # state_dict key, so the order below must not change.
        self.sequential = torch.nn.Sequential(
            expand,
            torch.nn.GLU(dim=1),
            depthwise,
            norm,
            torch.nn.SiLU(),
            project,
            torch.nn.Dropout(dropout),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""Run the convolution module.

        Args:
            input (torch.Tensor): with shape `(B, T, D)`.

        Returns:
            torch.Tensor: output, with shape `(B, T, D)`.
        """
        normed = self.layer_norm(input)
        # Conv1d expects (B, D, T); swap to channels-first and back afterwards.
        channels_first = normed.transpose(1, 2)
        return self.sequential(channels_first).transpose(1, 2)


class _FeedForwardModule(torch.nn.Module):
    r"""Position-wise feed-forward block: LayerNorm -> Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        input_dim (int): feature dimension of the input and of the output.
        hidden_dim (int): width of the intermediate projection.
        dropout (float, optional): dropout probability applied after each linear stage. (Default: 0.0)
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        # Stage order fixes the state_dict keys (sequential.0, .1, .4), so keep it stable.
        stages = [
            torch.nn.LayerNorm(input_dim),
            torch.nn.Linear(input_dim, hidden_dim, bias=True),
            torch.nn.SiLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, input_dim, bias=True),
            torch.nn.Dropout(dropout),
        ]
        self.sequential = torch.nn.Sequential(*stages)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        r"""Apply the feed-forward stack over the last dimension.

        Args:
            input (torch.Tensor): with shape `(*, D)`.

        Returns:
            torch.Tensor: output, with shape `(*, D)`.
        """
        output = self.sequential(input)
        return output


class ConformerLayer(torch.nn.Module):
    r"""Single Conformer layer: half-step FFN, self-attention and convolution
    (in configurable order), a second half-step FFN, then a final LayerNorm.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of the feed-forward networks.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of the depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): run the convolution module before
            the attention module instead of after it. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        ffn_dim: int,
        num_attention_heads: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ) -> None:
        super().__init__()

        # First "macaron" feed-forward half.
        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)

        # Pre-norm multi-head self-attention with its own output dropout.
        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
        self.self_attn_dropout = torch.nn.Dropout(dropout)

        # The convolution module keeps the channel count equal to the model width.
        conv_options = dict(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            bias=True,
            use_group_norm=use_group_norm,
        )
        self.conv_module = _ConvolutionModule(**conv_options)

        # Second feed-forward half and the closing normalisation.
        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first

    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
        """Residual convolution step on a `(T, B, D)` tensor."""
        # The convolution module works batch-first, so swap T and B around the call.
        batch_first = input.transpose(0, 1)
        convolved = self.conv_module(batch_first).transpose(0, 1)
        return input + convolved

    def _apply_self_attention(
        self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Residual pre-norm self-attention step on a `(T, B, D)` tensor."""
        normed = self.self_attn_layer_norm(input)
        attended, _ = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.self_attn_dropout(attended) + input

    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        r"""Run the layer.

        Args:
            input (torch.Tensor): input, with shape `(T, B, D)`.
            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.

        Returns:
            torch.Tensor: output, with shape `(T, B, D)`.
        """
        # Each feed-forward module contributes with weight 0.5.
        x = self.ffn1(input) * 0.5 + input

        if self.convolution_first:
            x = self._apply_convolution(x)

        x = self._apply_self_attention(x, key_padding_mask)

        if not self.convolution_first:
            x = self._apply_convolution(x)

        x = self.ffn2(x) * 0.5 + x

        return self.final_layer_norm(x)


class Conformer(torch.nn.Module):
    r"""Conformer architecture introduced in
    *Conformer: Convolution-augmented Transformer for Speech Recognition*
    :cite:`gulati2020conformer`.

    Unlike the torchaudio original, :meth:`forward` here also returns the
    boolean padding mask it built, so callers can reuse it (e.g. as a
    decoder ``memory_key_padding_mask``).

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.0)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)

    Examples:
        >>> conformer = Conformer(
        >>>     input_dim=80,
        >>>     num_heads=4,
        >>>     ffn_dim=128,
        >>>     num_layers=4,
        >>>     depthwise_conv_kernel_size=31,
        >>> )
        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
        >>> input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
        >>> output, output_lengths, padding_mask = conformer(input, lengths)
    """

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        ffn_dim: int,
        num_layers: int,
        depthwise_conv_kernel_size: int,
        dropout: float = 0.0,
        use_group_norm: bool = False,
        convolution_first: bool = False,
    ):
        super().__init__()

        self.conformer_layers = torch.nn.ModuleList(
            [
                ConformerLayer(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout=dropout,
                    use_group_norm=use_group_norm,
                    convolution_first=convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def _lengths_to_padding_mask(
        self, lengths: torch.Tensor, max_length: Optional[int] = None
    ) -> torch.Tensor:
        r"""Build a boolean mask that is ``True`` at padded positions.

        Args:
            lengths (torch.Tensor): with shape `(B,)`, number of valid frames per batch element.
            max_length (int or None, optional): width of the mask. When ``None``
                the longest entry of ``lengths`` is used. (Default: ``None``)

        Returns:
            torch.Tensor: boolean mask with shape `(B, max_length)`.
        """
        if max_length is None:
            max_length = int(torch.max(lengths).item())
        positions = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype)
        # Broadcast (1, max_length) against (B, 1): a position is padding once it
        # reaches the sequence's valid length.
        return positions.unsqueeze(0) >= lengths.unsqueeze(1)

    def forward(
        self, input: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input (torch.Tensor): with shape `(B, T, input_dim)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor)
                torch.Tensor
                    output frames, with shape `(B, T, input_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output frames.
                torch.Tensor
                    boolean padding mask, with shape `(B, T)`; ``True`` marks padded frames.
        """
        # Size the mask by the actual time dimension rather than max(lengths):
        # if the batch is padded beyond its longest sequence the two differ, and
        # the mask must match T to be usable as a key padding mask.
        encoder_padding_mask = self._lengths_to_padding_mask(lengths, max_length=input.size(1))

        # The layers operate time-first: (B, T, D) -> (T, B, D).
        x = input.transpose(0, 1)
        for layer in self.conformer_layers:
            x = layer(x, encoder_padding_mask)
        return x.transpose(0, 1), lengths, encoder_padding_mask

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    Args:
        emb_size: embedding width; even columns hold sines, odd columns cosines.
        max_len: number of positions precomputed in the table.
        dropout: dropout probability applied to the sum.
    """

    def __init__(self, emb_size, max_len, dropout ):
        super().__init__()

        # One frequency per sin/cos column pair -> shape (ceil(emb_size / 2),).
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        # Column vector of positions -> shape (max_len, 1).
        positions = torch.arange(0, max_len).reshape(max_len, 1)
        angles = positions * inv_freq

        table = torch.zeros((max_len, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer: it follows the module across devices and is
        # saved in the state_dict, but the optimizer never updates it.
        # The extra middle axis (max_len, 1, emb_size) broadcasts over the batch.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding):
        """Return ``dropout(token_embedding + positions)``.

        The table is sliced along dim 0 by ``token_embedding.size(0)``, so the
        input is expected to be sequence-first.
        """
        seq_len = token_embedding.size(0)
        positioned = token_embedding + self.pos_embedding[:seq_len, :]
        return self.dropout(positioned)
    

class TokenEmbedding(nn.Module):
    """Trainable token embedding whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: embedding width.
    """

    def __init__(self, vocab_size, emb_size):
        super().__init__()
        # Index 1 is the padding row: it is initialised to zeros and gets no gradient.
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx = 1)
        self.emb_size = emb_size

    def forward(self, tokens):
        """Look up ``tokens`` (cast to int64) and scale the vectors."""
        looked_up = self.embedding(tokens.long())
        return looked_up * math.sqrt(self.emb_size)
    
    
class pre_TokenEmbedding(nn.Module):
    """Token embedding initialised from pretrained word vectors.

    Args:
        vocab: object exposing the pretrained matrix as ``vocab.vectors``.
        emb_size: embedding width, used only for the ``sqrt(emb_size)`` output scaling.
    """

    def __init__(self, vocab, emb_size):
        super().__init__()
        # freeze=False keeps the pretrained weights trainable (fine-tuning).
        pretrained = vocab.vectors
        self.embedding = nn.Embedding.from_pretrained(pretrained, padding_idx=1, freeze= False)
        self.emb_size = emb_size

    def forward(self, tokens):
        """Look up ``tokens`` (cast to int64) and scale the vectors."""
        looked_up = self.embedding(tokens.long())
        return looked_up * math.sqrt(self.emb_size)
    

class Seq2SeqTransformer(nn.Module):
    """Sequence-to-sequence model: Conformer encoder + ``nn.TransformerDecoder``.

    Source and target tokens are embedded from the same pretrained vectors,
    given sinusoidal positions, encoded by the Conformer and decoded into
    target-vocabulary logits.

    Args:
        num_encoder_layers: number of Conformer layers.
        num_decoder_layers: number of Transformer decoder layers.
        emb_size: model / embedding width.
        nhead: attention heads in encoder and decoder.
        src_vocab_size: accepted for interface compatibility; not used here.
        tgt_vocab_size: output size of the final projection.
        max_len: number of positions in the positional-encoding table.
        pre_vector: object exposing pretrained vectors as ``.vectors``.
        dim_feedforward: hidden width of the feed-forward networks.
        dropout: dropout probability used throughout.
    """

    def __init__(self,
                 num_encoder_layers,
                 num_decoder_layers,
                 emb_size,
                 nhead,
                 src_vocab_size,
                 tgt_vocab_size,
                 max_len,
                 pre_vector,
                 dim_feedforward = 512,
                 dropout = 0.4):
        super().__init__()

        self.pre_vector = pre_vector

        # Encoder: Conformer with a fixed depthwise kernel size of 31.
        self.conformer = Conformer(
            input_dim=emb_size,
            num_heads=nhead,
            ffn_dim=dim_feedforward,
            num_layers=num_encoder_layers,
            depthwise_conv_kernel_size=31,
            dropout=dropout,
        )

        # Decoder: the prototype layer is kept as an attribute as well, so its
        # parameters appear in the state_dict under ``decoder_layer.*``.
        layer_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=emb_size,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            nhead=nhead,
            device=layer_device,
        )
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)

        # Projection from decoder states to target-vocabulary logits.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

        # Separate embedding modules for source and target, both from pre_vector.
        self.src_tok_emb = pre_TokenEmbedding(pre_vector, emb_size)
        self.tgt_tok_emb = pre_TokenEmbedding(pre_vector, emb_size)

        self.positional_encoding = PositionalEncoding(emb_size, max_len, dropout = dropout)

    def forward(self, src, tgt, tgt_mask, tgt_padding_mask, lengths):
        """Return target-vocabulary logits for a teacher-forced batch.

        Args:
            src: source token ids, sequence-first.
            tgt: target token ids, sequence-first.
            tgt_mask: attention mask for the decoder self-attention.
            tgt_padding_mask: key padding mask for the target sequence.
            lengths: valid source length per batch element, passed to the Conformer.
        """
        # Embed + add positions: (seq, batch, dim); the Conformer wants batch-first.
        src_emb = self.positional_encoding(self.src_tok_emb(src)).permute(1,0,2)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))

        # The encoder returns (frames, lengths, padding_mask); lengths are not needed here.
        encoded, _, encoder_padding_mask = self.conformer(src_emb, lengths)
        memory = encoded.permute(1,0,2)  # back to (seq, batch, dim) for the decoder

        decoded = self.transformer_decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src, lengths):
        """Embed ``src`` and run the Conformer; returns the Conformer's raw output tuple (batch-first frames first)."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.conformer(src_emb.permute(1,0,2), lengths)

    def decoder(self,tgt, memory, tgt_mask):
        """Embed ``tgt`` and run the Transformer decoder over ``memory`` (no padding masks are applied)."""
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask)


In [61]:
class Denti_chatbot(nn.Module):
    """Inference wrapper around five per-patient seq2seq chatbots.

    For each patient persona (1-5) the constructor loads a soyspacing
    ``CountSpace`` model, FastText vectors with a matching vocabulary, and a
    ``Seq2SeqTransformer`` checkpoint. ``forward(text, condition)`` answers
    ``text`` with the persona selected by ``condition``.

    Attributes kept for compatibility: ``soy_model_patient1..5``,
    ``transformer_p1..5``, ``pre_vocab``, ``pre_vector`` and the upper-case
    hyper-parameter fields.
    """

    # Personas served by this wrapper; every per-patient loop iterates over this.
    _PATIENT_IDS = (1, 2, 3, 4, 5)
    _PREPRO_DIR = "/data/kmkim/chatbot_data_model/dental_5/5_data/prepro"
    _CHECKPOINT_DIR = "/data/kmkim/chatbot_data_model/dental_5/conformer_save"
    # Special tokens occupy indices 0-3 of every vocabulary, in this order.
    _SPECIALS = ("<unk>", "<pad>", "<sos>", "<eos>")
    _SOS_IDX = 2
    _EOS_IDX = 3

    def __init__(self):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mecab = Mecab()

        # One spacing-correction model per persona (soy_model_patient1..5).
        for pid in self._PATIENT_IDS:
            spacer = CountSpace()
            spacer.load_model(f"{self._PREPRO_DIR}/patient_{pid}_patient_data_soy.model",
                              json_format=False)
            setattr(self, f"soy_model_patient{pid}", spacer)

        self.pre_vocab, self.pre_vector = self.load_embedding()

        self.EMB_SIZE = 256
        self.NHEAD = 8
        self.FFN_HID_DIM = 512
        self.BATCH_SIZE = 128
        self.NUM_ENCODER_LAYERS = 4  # must match the layer count of the saved checkpoints
        self.NUM_DECODER_LAYERS = 4  # same as above
        self.max_length = 54

        # One seq2seq model per persona (transformer_p1..5).
        for pid in self._PATIENT_IDS:
            key = f"patient{pid}"
            vocab_size = len(self.pre_vocab[key])
            model = Seq2SeqTransformer(self.NUM_ENCODER_LAYERS, self.NUM_DECODER_LAYERS,
                                       self.EMB_SIZE, self.NHEAD, vocab_size, vocab_size,
                                       self.max_length, self.pre_vector[key], self.FFN_HID_DIM)
            model = model.to(self.device)
            # map_location makes GPU-saved checkpoints loadable on the CPU fallback.
            state = torch.load(f"{self._CHECKPOINT_DIR}/Conformer_patient{pid}.pt",
                               map_location=self.device)
            model.load_state_dict(state)
            setattr(self, f"transformer_p{pid}", model)

    @staticmethod
    def _patient_id(condition):
        """Map ``condition`` to a persona id: 1-4 select themselves, anything else selects 5."""
        for pid in (1, 2, 3, 4):
            if condition == pid:
                return pid
        return 5

    def tokenizer(self, sentence):
        """Split ``sentence`` into morphemes with Mecab."""
        return self.mecab.morphs(sentence)

    def sequential_transforms(self, *transforms):
        """Return a function that applies ``transforms`` to its input from left to right."""
        def func(txt_input):
            for transform in transforms:
                txt_input = transform(txt_input)
            return txt_input
        return func

    def tensor_transform(self,token_ids):
        """Wrap ``token_ids`` with <sos>/<eos> indices and return a 1-D int64 tensor."""
        # The explicit dtype keeps an empty token list from producing a float tensor.
        return torch.cat((torch.tensor([self._SOS_IDX]),
                          torch.tensor(token_ids, dtype=torch.long),
                          torch.tensor([self._EOS_IDX])))

    def load_embedding(self):
        """Load FastText vectors per persona and build matching vocabularies.

        Four zero vectors are prepended for the special tokens and every
        original index is shifted by four, so vector rows line up with the
        vocabulary indices.

        Returns:
            (vocab_dic, vector_dic): dicts keyed ``'patient1'`` .. ``'patient5'``.
        """
        # NOTE: this changes the process-wide working directory; the vector files
        # are opened by relative name, so the chdir is kept as-is.
        os.chdir(self._PREPRO_DIR)

        specials = list(self._SPECIALS)
        offset = len(specials)
        vocab_dic = {}
        vector_dic = {}
        for pid in self._PATIENT_IDS:
            key = f"patient{pid}"
            vectors = torchtext.vocab.Vectors(f"FastText_patient_{pid}_patient_data")
            # Build the vocabulary from the original (unshifted) token table.
            vocabulary = torchtext.vocab.vocab(vectors.stoi, min_freq=0, specials=specials)

            # Zero rows for the specials; width follows the loaded vectors.
            special_rows = torch.zeros(offset, vectors.vectors.size(1),
                                       dtype=vectors.vectors.dtype)
            vectors.vectors = torch.cat([special_rows, vectors.vectors], dim=0)

            shifted = {token: index + offset for token, index in vectors.stoi.items()}
            shifted.update((token, index) for index, token in enumerate(specials))
            vectors.stoi = shifted

            # Unknown tokens map to <unk> (index 0).
            vocabulary.set_default_index(0)

            vocab_dic[key] = vocabulary
            vector_dic[key] = vectors

        return vocab_dic, vector_dic

    def greedy_decode(self, model, lengths, src, max_len, start_symbol):
        """Greedily decode one source sequence.

        Decoding stops after ``max_len`` steps or when the <eos> index is
        predicted (the <eos> token itself is not appended). The stop test
        assumes a batch of exactly one sequence.

        Args:
            model: seq2seq model exposing ``encode``, ``decoder`` and ``generator``.
            lengths: valid source lengths passed to ``model.encode``.
            src: source token ids with shape (seq, batch).
            max_len: maximum number of decoding steps (int or one-element tensor).
            start_symbol: index of the first target token.

        Returns:
            Tensor of generated token ids with shape (generated_len, batch).
        """
        # encode() returns batch-first frames; the decoder wants (seq, batch, dim).
        memory = model.encode(src, lengths)[0].permute(1,0,2)
        ys = torch.full((1, src.size(1)), start_symbol, dtype=torch.long, device=self.device)
        for _ in range(int(max_len)):
            tgt_mask = self.generate_square_subsequent_mask(ys.size(0))
            out = model.decoder(ys, memory, tgt_mask)  # (seq, batch, hidden)
            logits = model.generator(out[-1])  # last time step only
            _, next_word = torch.max(logits, dim = 1)
            if next_word == self._EOS_IDX:
                break
            ys = torch.cat([ys, next_word.unsqueeze(0)], dim = 0)
        return ys

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float mask: 0.0 on and below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=self.device), diagonal=1)

    def translate(self, model, src_sentence, pre_vocab, condition):
        """Generate the reply to ``src_sentence`` with ``model`` and fix its spacing.

        Relies on ``self.text_transform`` having been set (``forward`` does this).

        Args:
            model: seq2seq model of the selected persona.
            src_sentence: raw input text.
            pre_vocab: vocabulary used to map generated indices back to tokens.
            condition: persona selector; 1-4 pick that persona, anything else picks 5.

        Returns:
            The spacing-corrected reply string.
        """
        model.eval()
        src = self.text_transform(src_sentence).view(-1, 1).to(self.device)

        num_tokens = src.shape[0]
        lengths = torch.tensor([num_tokens]).to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            tgt_tokens = self.greedy_decode(
                model, lengths, src, max_len=num_tokens + 4,
                start_symbol=self._SOS_IDX).flatten()

        ans = "".join(pre_vocab.lookup_tokens(tgt_tokens.tolist()))
        for special in ("<sos>", "<eos>", "<pad>"):
            ans = ans.replace(special, "")

        # Morphemes were joined without spaces; the persona's spacing model restores them.
        spacer = getattr(self, f"soy_model_patient{self._patient_id(condition)}")
        return spacer.correct(ans, min_count = 1, force_abs_threshold=0.3,
                              nonspace_threshold= -0.3, space_threshold= 0.3)[0]

    def forward(self, text, condition):
        """Answer ``text`` as the persona chosen by ``condition`` (1-4; anything else means 5)."""
        pid = self._patient_id(condition)
        vocabulary = self.pre_vocab[f"patient{pid}"]
        # Text pipeline: morphemes -> vocabulary indices -> tensor with <sos>/<eos>.
        self.text_transform = self.sequential_transforms(self.tokenizer, vocabulary, self.tensor_transform)
        model = getattr(self, f"transformer_p{pid}")
        return self.translate(model, text, vocabulary, condition)
        