In [9]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from datasets import load_metric
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizer, RobertaModel


In [6]:

class DocumentEncoder(nn.Module):
    """Encode raw document text with RoBERTa followed by a bidirectional LSTM.

    Args:
        roberta_model_name: Name or path passed to ``from_pretrained`` for both
            the RoBERTa model and its tokenizer.
        hidden_size: Size of the encoder output features. Each LSTM direction
            gets ``hidden_size // 2`` units, so an odd ``hidden_size`` yields
            ``hidden_size - 1`` output features.
    """

    def __init__(self, roberta_model_name, hidden_size):
        super(DocumentEncoder, self).__init__()
        self.roberta = RobertaModel.from_pretrained(roberta_model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(roberta_model_name)
        self.bilstm = nn.LSTM(self.roberta.config.hidden_size,
                              hidden_size // 2,  # BiLSTM has half the hidden size per direction
                              num_layers=1,
                              bidirectional=True,
                              # RoBERTa emits (batch, seq, hidden). Without batch_first the
                              # LSTM would treat the batch axis as the time axis.
                              batch_first=True)

    def forward(self, input_text):
        """Tokenize and encode a batch of documents.

        Args:
            input_text: A string or a list of strings to encode.

        Returns:
            A tuple ``(lstm_output, attention_mask)`` where ``lstm_output`` has
            shape ``(batch, seq_len, 2 * (hidden_size // 2))`` and
            ``attention_mask`` is a boolean ``(batch, seq_len)`` tensor that is
            True on real tokens and False on padding.
        """
        # Tokenize input text
        tokens = self.tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)

        # The tokenizer always produces CPU tensors; move them to wherever the
        # model weights live so the encoder also works after `.to("cuda")`.
        model_device = next(self.parameters()).device
        input_ids = tokens.input_ids.to(model_device)
        attention_mask = tokens.attention_mask.to(model_device)

        # Pass input through RoBERTa
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_states = outputs.last_hidden_state

        # Pass RoBERTa outputs through BiLSTM
        lstm_output, _ = self.bilstm(last_hidden_states)

        # Downstream attention inverts the mask with `~` and feeds it to
        # `masked_fill`, both of which need a boolean mask rather than 0/1 ints.
        return lstm_output, attention_mask.bool()


In [10]:

class SummaryDecoder(nn.Module):
    """One-step LSTM decoder with additive attention over the encoded document.

    Args:
        hidden_size: Feature size of the decoder input, the LSTM state and the
            encoded document vectors.
        output_size: Number of output classes (vocabulary size).
    """

    def __init__(self, hidden_size, output_size):
        super(SummaryDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1)
        self.attention_document = nn.Linear(hidden_size, hidden_size)
        self.u1 = nn.Parameter(torch.randn(hidden_size))
        self.Wout = nn.Linear(hidden_size * 2, output_size)

    @staticmethod
    def _masked_softmax(scores, mask):
        """Softmax over dim 1 that gives exactly zero weight to masked positions.

        A large finite fill value is used instead of ``-inf`` so that a row with
        no valid position produces all-zero weights rather than NaN.
        """
        keep = mask.to(torch.bool)
        scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)
        return weights * keep.to(weights.dtype)

    def forward(self, input_token, hidden_state, document_output, document_mask):
        """Run a single decoding step.

        Args:
            input_token: ``(batch, hidden_size)`` input vector for this step.
            hidden_state: LSTM ``(h, c)`` tuple from the previous step, or None
                to start from zeros.
            document_output: ``(batch, src_len, hidden_size)`` encoded document.
            document_mask: ``(batch, src_len)`` mask, truthy on real tokens.
                Boolean and 0/1 integer masks are both accepted.

        Returns:
            ``(output, hidden_state)`` where ``output`` holds the
            ``(batch, output_size)`` logits and ``hidden_state`` is the updated
            LSTM state.
        """
        # LSTM step: add a length-1 time axis, then drop it again.
        lstm_output, hidden_state = self.lstm(input_token.unsqueeze(0), hidden_state)
        decoder_state = lstm_output.squeeze(0)  # (batch, hidden)

        # Attention over document. The query needs shape (batch, 1, hidden) so it
        # broadcasts across source positions, not across the batch.
        query = self.attention_document(decoder_state).unsqueeze(1)
        document_scores = torch.matmul(torch.tanh(query + document_output), self.u1)  # (batch, src_len)
        document_attention_weights = self._masked_softmax(document_scores, document_mask)
        # Weights are (batch, src_len); add a feature axis before weighting the vectors.
        document_context_vector = torch.sum(
            document_attention_weights.unsqueeze(-1) * document_output, dim=1)

        # Combine context vector with LSTM output
        combined_context = torch.cat((decoder_state, document_context_vector), dim=1)
        output = self.Wout(combined_context)

        return output, hidden_state


In [None]:
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper for abstractive summarization.

    The decoder consumes ``(batch, decoder.hidden_size)`` vectors and emits
    ``(batch, decoder.output_size)`` logits, so logits cannot be fed back
    directly. This wrapper owns a token embedding that maps the previous token
    id (gold or predicted) to the next decoder input.

    Args:
        encoder: Module mapping raw text to ``(document_output, document_mask)``.
        decoder: Step decoder exposing ``hidden_size`` and ``output_size``.
        device: Device on which decoding tensors are placed.
    """

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # Maps token ids back into the decoder's input space.
        self.token_embedding = nn.Embedding(decoder.output_size, decoder.hidden_size).to(device)

    def _encode(self, document_text):
        """Encode the document and place the results on ``self.device``."""
        document_output, document_mask = self.encoder(document_text)
        return document_output.to(self.device), document_mask.to(self.device)

    def _initial_decoder_input(self, batch_size):
        """Return the all-zero vector fed to the decoder at the first step."""
        return torch.zeros((batch_size, self.decoder.hidden_size), device=self.device)

    def forward(self, document_text, target_summary=None, teacher_forcing_ratio=0.5):
        """Compute per-step logits for training.

        Args:
            document_text: Raw text batch understood by the encoder.
            target_summary: ``(batch, tgt_len)`` tensor of gold token ids.
                Required; it also fixes the number of decoding steps.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the argmax
                prediction as the next input.

        Returns:
            ``(batch, tgt_len, output_size)`` logits.

        Raises:
            ValueError: If ``target_summary`` is None.
        """
        if target_summary is None:
            raise ValueError(
                "target_summary is required by forward(); use generate_summary() for inference."
            )
        target_summary = target_summary.to(self.device)

        # Encode the document
        document_output, document_mask = self._encode(document_text)

        # Prepare initial decoder input and hidden state
        decoder_input = self._initial_decoder_input(document_output.size(0))
        hidden_state = None  # LSTM hidden state will be initialized to zero by default

        # Iterate over the target sequence
        outputs = []
        for t in range(target_summary.size(1)):
            output, hidden_state = self.decoder(decoder_input, hidden_state, document_output, document_mask)
            outputs.append(output.unsqueeze(1))

            # Teacher forcing: use actual target token or predicted token as next input
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            next_token_ids = target_summary[:, t] if teacher_force else output.argmax(dim=1)
            decoder_input = self.token_embedding(next_token_ids)

        return torch.cat(outputs, dim=1)

    def generate_summary(self, document_text, max_length=50):
        """Greedily decode ``max_length`` token ids for each document.

        The module is switched to eval mode for decoding and its previous
        train/eval mode is restored afterwards.

        Returns:
            ``(batch, max_length)`` tensor of token ids.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Encode the document
                document_output, document_mask = self._encode(document_text)

                # Prepare initial decoder input and hidden state
                batch_size = document_output.size(0)
                decoder_input = self._initial_decoder_input(batch_size)
                hidden_state = None  # LSTM hidden state will be initialized to zero by default

                # Generate summary tokens
                summary_tokens = []
                for _ in range(max_length):
                    output, hidden_state = self.decoder(decoder_input, hidden_state, document_output, document_mask)
                    next_token_ids = output.argmax(dim=1)
                    summary_tokens.append(next_token_ids.unsqueeze(1))

                    # Next input is the embedding of the token just predicted
                    decoder_input = self.token_embedding(next_token_ids)

                if not summary_tokens:
                    # max_length <= 0: nothing to concatenate.
                    return torch.empty((batch_size, 0), dtype=torch.long, device=self.device)
                return torch.cat(summary_tokens, dim=1)
        finally:
            self.train(was_training)


In [None]:
# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_size = 768
# Load the tokenizer once and reuse it. `len(tokenizer)` is the documented way to
# get the full vocabulary size (including added tokens); the slow RobertaTokenizer
# does not reliably expose a `.vocab` attribute.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
output_size = len(tokenizer)

encoder = DocumentEncoder('roberta-base', hidden_size).to(device)
decoder = SummaryDecoder(hidden_size, output_size).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
# NOTE(review): this optimizer also updates the RoBERTa weights; lr=0.001 is high
# for fine-tuning a pretrained transformer — confirm this is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Training loop (example, assuming `dataloader` provides batches of document texts and target summaries)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        document_text, target_summary = batch
        # `document_text` is raw text that the encoder tokenizes itself, so only
        # the target token ids are moved to the device.
        target_summary = target_summary.to(device)

        optimizer.zero_grad()
        output = model(document_text, target_summary)
        loss = criterion(output.reshape(-1, output_size), target_summary.reshape(-1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
    # Report the epoch mean rather than only the last batch; this also avoids a
    # NameError on `loss` when the dataloader is empty.
    print(f'Epoch {epoch+1}, Train Loss: {epoch_loss / max(num_batches, 1):.4f}')

# After training, generate a summary (example)
model.eval()
test_input_text = ["Your test input text here..."]
summary_tokens = model.generate_summary(test_input_text)
summary_text = tokenizer.decode(summary_tokens[0], skip_special_tokens=True)
print(f'Generated Summary: {summary_text}')


# With Graph Embeddings

In [None]:

class SummaryDecoder(nn.Module):
    """One-step LSTM decoder with attention over graph nodes and the document.

    Graph attention is computed first; the resulting graph context vector then
    conditions the document attention, and both context vectors are combined
    with the LSTM output to produce the logits.

    Args:
        hidden_size: Feature size of the decoder input, the LSTM state, the
            graph node vectors and the encoded document vectors.
        output_size: Number of output classes (vocabulary size).
    """

    def __init__(self, hidden_size, output_size):
        super(SummaryDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1)
        self.attention_graph = nn.Linear(hidden_size, hidden_size)
        self.attention_document = nn.Linear(hidden_size, hidden_size)
        self.u0 = nn.Parameter(torch.randn(hidden_size))
        self.u1 = nn.Parameter(torch.randn(hidden_size))
        self.Wout = nn.Linear(hidden_size * 3, output_size)

    @staticmethod
    def _masked_softmax(scores, mask):
        """Softmax over dim 1 that gives exactly zero weight to masked positions.

        A large finite fill value is used instead of ``-inf`` so that a row with
        no valid position (e.g. an empty graph) produces all-zero weights
        rather than NaN.
        """
        keep = mask.to(torch.bool)
        scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)
        return weights * keep.to(weights.dtype)

    def forward(self, input_token, hidden_state, document_output, graph_output, document_mask, graph_mask):
        """Run a single decoding step.

        Args:
            input_token: ``(batch, hidden_size)`` input vector for this step.
            hidden_state: LSTM ``(h, c)`` tuple from the previous step, or None
                to start from zeros.
            document_output: ``(batch, src_len, hidden_size)`` encoded document.
            graph_output: ``(batch, num_nodes, hidden_size)`` graph node vectors.
            document_mask: ``(batch, src_len)`` mask, truthy on real tokens.
            graph_mask: ``(batch, num_nodes)`` mask, truthy on real nodes.
                Boolean and 0/1 integer masks are both accepted.

        Returns:
            ``(output, hidden_state)`` where ``output`` holds the
            ``(batch, output_size)`` logits and ``hidden_state`` is the updated
            LSTM state.
        """
        # LSTM step: add a length-1 time axis, then drop it again.
        lstm_output, hidden_state = self.lstm(input_token.unsqueeze(0), hidden_state)
        decoder_state = lstm_output.squeeze(0)  # (batch, hidden)

        # Attention over graph. Queries are shaped (batch, 1, hidden) so they
        # broadcast across nodes, not across the batch.
        graph_query = self.attention_graph(decoder_state).unsqueeze(1)
        graph_scores = torch.matmul(torch.tanh(graph_query + graph_output), self.u0)  # (batch, num_nodes)
        graph_attention_weights = self._masked_softmax(graph_scores, graph_mask)
        graph_context_vector = torch.sum(
            graph_attention_weights.unsqueeze(-1) * graph_output, dim=1)  # (batch, hidden)

        # Attention over document, conditioned on the graph context
        document_query = self.attention_document(decoder_state).unsqueeze(1)
        document_scores = torch.matmul(
            torch.tanh(document_query + document_output + graph_context_vector.unsqueeze(1)),
            self.u1)  # (batch, src_len)
        document_attention_weights = self._masked_softmax(document_scores, document_mask)
        document_context_vector = torch.sum(
            document_attention_weights.unsqueeze(-1) * document_output, dim=1)  # (batch, hidden)

        # Combine context vectors with LSTM output
        combined_context = torch.cat(
            (decoder_state, document_context_vector, graph_context_vector), dim=1)
        output = self.Wout(combined_context)

        return output, hidden_state

In [None]:
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper for graph-augmented abstractive summarization.

    The decoder consumes ``(batch, decoder.hidden_size)`` vectors and emits
    ``(batch, decoder.output_size)`` logits, so logits cannot be fed back
    directly. This wrapper owns a token embedding that maps the previous token
    id (gold or predicted) to the next decoder input.

    Args:
        encoder: Module mapping raw text to ``(document_output, document_mask)``.
        decoder: Step decoder exposing ``hidden_size`` and ``output_size`` and
            accepting document and graph inputs.
        device: Device on which decoding tensors are placed.
    """

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # Maps token ids back into the decoder's input space.
        self.token_embedding = nn.Embedding(decoder.output_size, decoder.hidden_size).to(device)

    def _encode_inputs(self, document_text, graph_nodes, graph_mask):
        """Encode the document and move document and graph tensors to ``self.device``.

        The graph is not encoded here: ``graph_nodes`` is used as-is (assuming
        it is already processed and is a tensor).
        """
        document_output, document_mask = self.encoder(document_text)
        return (
            document_output.to(self.device),
            document_mask.to(self.device),
            graph_nodes.to(self.device),
            graph_mask.to(self.device),
        )

    def _initial_decoder_input(self, batch_size):
        """Return the all-zero vector fed to the decoder at the first step."""
        return torch.zeros((batch_size, self.decoder.hidden_size), device=self.device)

    def forward(self, document_text, graph_nodes, graph_mask, target_summary=None, teacher_forcing_ratio=0.5):
        """Compute per-step logits for training.

        Args:
            document_text: Raw text batch understood by the encoder.
            graph_nodes: ``(batch, num_nodes, hidden)`` tensor of node vectors.
            graph_mask: ``(batch, num_nodes)`` mask, truthy on real nodes.
            target_summary: ``(batch, tgt_len)`` tensor of gold token ids.
                Required; it also fixes the number of decoding steps.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the argmax
                prediction as the next input.

        Returns:
            ``(batch, tgt_len, output_size)`` logits.

        Raises:
            ValueError: If ``target_summary`` is None.
        """
        if target_summary is None:
            raise ValueError(
                "target_summary is required by forward(); use generate_summary() for inference."
            )
        target_summary = target_summary.to(self.device)

        # Encode the document and stage the graph tensors
        document_output, document_mask, graph_output, graph_mask = self._encode_inputs(
            document_text, graph_nodes, graph_mask)

        # Prepare initial decoder input and hidden state
        decoder_input = self._initial_decoder_input(document_output.size(0))
        hidden_state = None  # LSTM hidden state will be initialized to zero by default

        # Iterate over the target sequence
        outputs = []
        for t in range(target_summary.size(1)):
            output, hidden_state = self.decoder(decoder_input, hidden_state, document_output, graph_output, document_mask, graph_mask)
            outputs.append(output.unsqueeze(1))

            # Teacher forcing: use actual target token or predicted token as next input
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            next_token_ids = target_summary[:, t] if teacher_force else output.argmax(dim=1)
            decoder_input = self.token_embedding(next_token_ids)

        return torch.cat(outputs, dim=1)

    def generate_summary(self, document_text, graph_nodes, graph_mask, max_length=50):
        """Greedily decode ``max_length`` token ids for each document.

        The module is switched to eval mode for decoding and its previous
        train/eval mode is restored afterwards.

        Returns:
            ``(batch, max_length)`` tensor of token ids.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Encode the document and stage the graph tensors
                document_output, document_mask, graph_output, graph_mask = self._encode_inputs(
                    document_text, graph_nodes, graph_mask)

                # Prepare initial decoder input and hidden state
                batch_size = document_output.size(0)
                decoder_input = self._initial_decoder_input(batch_size)
                hidden_state = None  # LSTM hidden state will be initialized to zero by default

                # Generate summary tokens
                summary_tokens = []
                for _ in range(max_length):
                    output, hidden_state = self.decoder(decoder_input, hidden_state, document_output, graph_output, document_mask, graph_mask)
                    next_token_ids = output.argmax(dim=1)
                    summary_tokens.append(next_token_ids.unsqueeze(1))

                    # Next input is the embedding of the token just predicted
                    decoder_input = self.token_embedding(next_token_ids)

                if not summary_tokens:
                    # max_length <= 0: nothing to concatenate.
                    return torch.empty((batch_size, 0), dtype=torch.long, device=self.device)
                return torch.cat(summary_tokens, dim=1)
        finally:
            self.train(was_training)


# Backward-compatible alias: `generate_summary` used to be a free function taking
# the model as its first argument, so `generate_summary(model, ...)` still works.
generate_summary = Seq2Seq.generate_summary


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_size = 768
# Load the tokenizer once and reuse it. `len(tokenizer)` is the documented way to
# get the full vocabulary size (including added tokens); the slow RobertaTokenizer
# does not reliably expose a `.vocab` attribute.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
output_size = len(tokenizer)

encoder = DocumentEncoder('roberta-base', hidden_size).to(device)
decoder = SummaryDecoder(hidden_size, output_size).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
# NOTE(review): this optimizer also updates the RoBERTa weights; lr=0.001 is high
# for fine-tuning a pretrained transformer — confirm this is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Training loop (example, assuming `dataloader` provides batches of document texts, graph nodes, graph masks, and target summaries)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        document_text, graph_nodes, graph_mask, target_summary = batch
        # `document_text` is raw text that the encoder tokenizes itself, so only
        # the tensors are moved to the device.
        graph_nodes = graph_nodes.to(device)
        graph_mask = graph_mask.to(device)
        target_summary = target_summary.to(device)

        optimizer.zero_grad()
        output = model(document_text, graph_nodes, graph_mask, target_summary)
        loss = criterion(output.reshape(-1, output_size), target_summary.reshape(-1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
    # Report the epoch mean rather than only the last batch; this also avoids a
    # NameError on `loss` when the dataloader is empty.
    print(f'Epoch {epoch+1}, Train Loss: {epoch_loss / max(num_batches, 1):.4f}')

# After training, generate a summary (example)
model.eval()
test_input_text = ["Your test input text here..."]
test_graph_nodes = torch.randn((1, 10, hidden_size))  # Example graph nodes
test_graph_mask = torch.ones((1, 10), dtype=torch.bool)  # Example graph mask
summary_tokens = model.generate_summary(test_input_text, test_graph_nodes, test_graph_mask)
summary_text = tokenizer.decode(summary_tokens[0], skip_special_tokens=True)
print(f'Generated Summary: {summary_text}')