#### load data

In [None]:
import transformers
import torch
from transformers import BertModel, BertTokenizerFast

In [None]:
# Name of the pretrained checkpoint; passed to from_pretrained() for both the
# tokenizer and the BERT encoder.
weight = 'bert-base-uncased'
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Token length that the tokenizer calls pad/truncate every sentence to.
max_len = 35

In [None]:
import json

train_path = r'D:\LeStoreDownload\AKE-copy01\datas\Election-Trec\train.json'
# train_path = '../datas/daily life/train.json'
test_path = r'D:\LeStoreDownload\AKE-copy01\datas\Election-Trec\test.json'
# test_path = '../datas/daily life/test.json'

# Read both splits through context managers so the file handles are closed as
# soon as parsing finishes instead of being left to the garbage collector.
with open(train_path, 'r', encoding='utf-8') as train_fp:
    train_file = json.load(train_fp)
with open(test_path, 'r', encoding='utf-8') as test_fp:
    test_file = json.load(test_fp)

In [None]:
# Collect, for every training sentence, its text, per-word signal features,
# per-word tags and word count.
train_sens, train_tags = [], []
train_Feature = []
train_word_nums = []

for items in train_file.values():
    # Each item is read as [word, feature values..., tag].
    train_sens.append(' '.join(item[0] for item in items).strip())
    train_word_nums.append(len(items))
    train_Feature.append([item[1:-1] for item in items])   # ET+EEG: [1: -1]
    train_tags.append([item[-1] for item in items])

In [None]:
# Collect, for every test sentence, its text, per-word signal features,
# per-word tags and word count.
test_sens, test_tags = [], []
test_Feature = []
test_word_nums = []

for items in test_file.values():
    # Each item is read as [word, feature values..., tag].
    test_sens.append(' '.join(item[0] for item in items).strip())
    test_word_nums.append(len(items))
    test_Feature.append([item[1:-1] for item in items])    # ET+EEG: [1: -1]
    test_tags.append([item[-1] for item in items])

In [None]:
len(test_sens)

#### build dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
tokenizer = BertTokenizerFast.from_pretrained(weight)

In [None]:
# Tag string -> integer id. 'none' (id 0) is the label align_label gives to
# positions that have no word (special/padding tokens) or whose lookup fails;
# id 0 is also the ignore_index of the loss and the value filtered out by
# calculate_f1.
label_to_ids = {'none': 0, 'B': 1, 'I': 2, 'E': 3, 'S': 4, "O": 5}
# label_to_ids = {'O': 0, 'B': 1, 'I': 2, 'E': 3, 'S': 4}

In [None]:
from tqdm import tqdm
import numpy as np

class MyDataset(Dataset):
    """Token-classification dataset pairing BERT inputs with token-level features and labels.

    ``encode()`` must be called once after construction; ``len()`` and indexing
    rely on the tensors it creates.
    """

    def __init__(self, texts, old_features, tags):
        """Store the raw word-level inputs.

        Args:
            texts: One sentence string per sample.
            old_features: Per-sample list of per-word feature vectors.
            tags: Per-sample list of per-word tag strings.
        """
        self.texts = texts
        self.tags = tags
        self.old_features = old_features

        self.labels = []
        self.tokens = []
        self.features = []

        self.input_ids = None
        self.attention_masks = None

    def encode(self):
        """Align labels/features to sub-tokens and tokenize every sentence.

        Results are built in local lists and assigned at the end, so calling
        this method again recomputes everything instead of appending to the
        output of a previous call.
        """
        labels, tokens, features = [], [], []
        for i in tqdm(range(len(self.texts))):
            sample_labels, sample_tokens, sample_features = align_label(
                self.texts[i], self.tags[i], self.old_features[i]
            )
            labels.append(sample_labels)
            tokens.append(sample_tokens)
            features.append(sample_features)

        self.labels = labels
        self.tokens = tokens
        self.features = np.array(features, float)
        self.inputs = tokenizer(self.texts, max_length=max_len, add_special_tokens=True,
                                padding='max_length', truncation=True, return_tensors='pt')
        self.input_ids = self.inputs['input_ids']
        self.attention_masks = self.inputs['attention_mask']

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, tokens, features, labels)`` for sample ``idx``."""
        return (
            self.input_ids[idx, :],
            self.attention_masks[idx, :],
            self.tokens[idx],
            torch.tensor(self.features[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx]),
        )

    def __len__(self):
        """Return the number of encoded samples."""
        return len(self.input_ids)

In [None]:
label_all_tokens = True



def align_label(text, labels, features):
    """Expand word-level labels and features to BERT sub-token level.

    Args:
        text: Sentence string whose whitespace-separated words correspond,
            in order, to the entries of ``labels`` and ``features``.
        labels: Per-word tag strings (keys of ``label_to_ids``).
        features: Per-word feature vectors.

    Returns:
        A ``(label_ids, tokens, new_features)`` tuple with one entry per
        token position of the padded/truncated encoding.
    """
    # Tokenize the sentence as pre-split words so that word_ids() counts the
    # same whitespace-separated words that `labels`/`features` are indexed by.
    # Tokenizing the raw string lets the tokenizer split on punctuation too,
    # which shifts every following word index away from its label.
    encoding = tokenizer(text.split(), is_split_into_words=True, max_length=max_len,
                         add_special_tokens=True, padding='max_length', truncation=True,
                         return_tensors='pt')
    word_ids = encoding.word_ids()
    tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])

    no_features = [0] * 25  # placeholder vector for positions without a word
    new_labels, new_features = [], []
    previous_word_idx = None

    for word_idx in word_ids:
        label, feature = 'none', no_features
        # A position gets its word's label when it starts a new word, or when
        # it continues one and label_all_tokens is enabled.
        if word_idx is not None and (label_all_tokens or word_idx != previous_word_idx):
            try:
                # Both lookups happen before anything is stored, so a failure
                # can never leave the two output lists with different lengths.
                label, feature = labels[word_idx], features[word_idx]
            except IndexError:
                # More words than labels/features: keep the placeholders.
                pass
        new_labels.append(label)
        new_features.append(feature)
        previous_word_idx = word_idx

    label_ids = [label_to_ids[label] for label in new_labels]
    return label_ids, tokens, new_features


In [None]:
train_dataset = MyDataset(train_sens, train_Feature, train_tags)
train_dataset.encode()

In [None]:
test_dataset = MyDataset(test_sens, test_Feature, test_tags)
test_dataset.encode()

In [None]:
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=128)
# Evaluation gains nothing from shuffling; a fixed order keeps runs reproducible.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=128)

#### construct bert model

In [None]:
import torch.nn as nn
import torch



class BertNerModel(nn.Module):
    """BERT token classifier fed with BERT hidden states concatenated with extra features.

    Args:
        num_labels: Number of output classes per token.
        num_extra_features: How many leading dimensions of ``extra_features``
            are appended to each token's hidden state. Defaults to 17, the
            eye-tracking slice used originally.
    """

    def __init__(self, num_labels, num_extra_features=17):
        super(BertNerModel, self).__init__()

        self.bert = BertModel.from_pretrained(weight)
        self.dropout = nn.Dropout(0.1)
        self.num_extra_features = num_extra_features
        # Read the encoder width from the loaded config instead of assuming 768.
        self.classifier = nn.Linear(self.bert.config.hidden_size + num_extra_features, num_labels)

    def forward(self, input_ids, attention_mask, extra_features, token_type_ids=None):
        """Return per-token logits.

        ``extra_features`` is sliced on its last axis and concatenated with
        the encoder output on the last axis, so its leading axes must match
        the encoder output's.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # outputs[0] is the per-token hidden-state sequence, not the pooled vector.
        sequence_output = self.dropout(outputs[0])

        # Keep only the leading feature dimensions (the eye-tracking signals).
        eye_tracking_features = extra_features[:, :, :self.num_extra_features]
        combined = torch.cat((sequence_output, eye_tracking_features), -1)

        return self.classifier(combined)



#### evaluate

In [None]:
def TagConvert(raw_tags, words_set, poss=None):
    """Decode per-token tag ids into keyphrase strings, one list per sample.

    Tag ids follow ``label_to_ids``: 1 starts a phrase, 2 continues it,
    3 ends it (the phrase is emitted), 4 is a single-word keyphrase, and
    0 / any other id contributes nothing.

    Args:
        raw_tags: 2-D array or tensor of tag ids indexed ``[sample][position]``.
        words_set: Tokens indexed ``[position][sample]``.
        poss: Optional per-sample collections of positions to skip.

    Returns:
        A list with one list of keyphrase strings per sample.
    """
    true_tags = []
    for i in range(raw_tags.shape[0]):
        row = raw_tags[i]
        kw_list = []
        phrase = ""
        for j in range(len(row)):
            item = row[j]
            if item == 0:
                continue
            # `is not None` (not `!= None`) so array-like `poss` is accepted too.
            if poss is not None and j in poss[i]:
                continue
            if item == 4:
                kw_list.append(str(words_set[j][i]))
            elif item == 1:
                phrase += str(words_set[j][i])
            elif item == 2:
                phrase += " " + str(words_set[j][i])
            elif item == 3:
                phrase += " " + str(words_set[j][i])
                kw_list.append(phrase)
                phrase = ""

        true_tags.append(kw_list)
    return true_tags

In [None]:
def evaluate(predict_data, target_data, topk=3):
    """Compute micro-averaged precision, recall and F1 over keyphrase sets.

    Args:
        predict_data: Per-sample predictions. Each entry is either a ranked
            sequence of keyphrases or a dict mapping keyphrase to score
            (ranked by descending score, ties broken by keyphrase).
        target_data: Per-sample gold keyphrases, aligned with ``predict_data``.
        topk: Only the first ``topk`` predictions of each sample are scored.

    Returns:
        ``(precision, recall, f1)`` as percentages rounded to two decimals.
    """
    true_count, pred_count, gold_count = 0.0, 0.0, 0.0
    for index, words in enumerate(predict_data):
        y_true = target_data[index]

        # Decide per sample: the container type of `predict_data` says nothing
        # about how an individual prediction is represented.
        if isinstance(words, dict):
            ranked = sorted(words.items(), key=lambda item: (-item[1], item[0]))
            y_pred = [keyword for keyword, _ in ranked]
        else:
            y_pred = words

        y_pred = y_pred[0: topk]
        true_count += len(set(y_pred) & set(y_true))
        pred_count += len(y_pred)
        gold_count += len(y_true)

    # Guard every ratio against an empty denominator.
    p = (true_count / pred_count) if pred_count != 0 else 0
    r = (true_count / gold_count) if gold_count != 0 else 0
    f1 = ((2 * r * p) / (r + p)) if (r + p) != 0 else 0

    p = round(p * 100, 2)
    r = round(r * 100, 2)
    f1 = round(f1 * 100, 2)

    return p, r, f1

In [None]:
import numpy as np

def calculate_f1(y_pred, y_true):
    """Flatten predictions and targets and drop positions whose target id is 0.

    No score is computed here; the filtered arrays are meant to be passed on
    to a metric function.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays restricted to positions
        where the target is non-zero.
    """
    gold = y_true.view(-1).detach().cpu().numpy()
    guess = y_pred.view(-1).detach().cpu().numpy()

    keep = gold != 0  # boolean mask of positions with a real label
    return guess[keep], gold[keep]

#### start training

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, AdamW

model = BertNerModel(num_labels=6).to(device)

optim = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)
# reduction='none' returns one loss value per position; targets equal to 0 are ignored.
loss_fn = CrossEntropyLoss(reduction='none', ignore_index=0).to(device)

In [None]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import f1_score

import os

epochs = 5
best_f1 = 0.0
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./pretrain_pt', exist_ok=True)
for epoch in tqdm(range(epochs)):
    # ---- training pass ----
    loss_value = 0.0
    model.train()
    label_true, label_pred = [], []
    for batch in train_dataloader:
        optim.zero_grad()
        input_ids, attention_masks, _, features, tags = batch
        pred_tags = model(input_ids.to(device), attention_masks.to(device), features.to(device))

        # CrossEntropyLoss expects the class dimension at position 1.
        loss = loss_fn(pred_tags.permute(0, 2, 1), tags.to(device))
        loss = loss.mean()
        loss.backward()
        optim.step()

        pred_tags = F.softmax(pred_tags, dim=-1)
        pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        loss_value += loss.item()

    label_train_f1 = f1_score(label_true, label_pred, average='macro')

    # ---- evaluation pass ----
    # model.eval() already switches dropout off. The Dropout modules' `p` must
    # not be overwritten: that change is permanent and would silently disable
    # dropout in every following training epoch.
    model.eval()
    kw_true, kw_pred = [], []
    label_true, label_pred = [], []
    for batch in test_dataloader:
        input_ids, attention_masks, tokens, features, tags = batch
        with torch.no_grad():
            pred_tags = model(input_ids.to(device), attention_masks.to(device), features.to(device))
            pred_tags = F.softmax(pred_tags, dim=-1)
            pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        # Positions whose gold tag is 0 are skipped when decoding predictions.
        poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]

        kw_true.extend(TagConvert(tags, tokens))
        kw_pred.extend(TagConvert(pred_tags, tokens, poss))

    label_f1 = f1_score(label_true, label_pred, average='macro')
    # evaluate() takes (predictions, gold); passing them the other way round
    # swaps precision with recall and applies the top-k cut to the gold list.
    P, R, F1 = evaluate(kw_pred, kw_true)

    if F1 > best_f1:
        best_f1 = F1
        torch.save(model.state_dict(), './pretrain_pt/bert_ET.pt')

    print("epoch{}:  loss:{:.2f}   train_f1_value:{:.2f}  test_f1_value:{:.2f}  kw_f1_value:{:.2f}".format(
        epoch+1, loss_value / len(train_dataloader), label_train_f1, label_f1, F1
    ))

#### inference

In [None]:
model = BertNerModel(num_labels=6)
# map_location lets a checkpoint saved from a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load('./pretrain_pt/bert_ET.pt', map_location=device))
model = model.to(device)

In [None]:
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import f1_score

# eval() is all that is needed to disable dropout; the Dropout modules' `p`
# is left untouched so the model object stays usable for further training.
model.eval()
kw_true, kw_pred = [], []
label_true, label_pred = [], []
for batch in test_dataloader:
    input_ids, attention_masks, tokens, extra_features, tags = batch
    with torch.no_grad():
        pred_tags = model(input_ids.to(device), attention_masks.to(device), extra_features.to(device))
        pred_tags = F.softmax(pred_tags, dim=-1)
        pred_tags = torch.argmax(pred_tags, dim=-1)

    y_pred, y_true = calculate_f1(pred_tags, tags)
    label_true.extend(y_true)
    label_pred.extend(y_pred)

    # Positions whose gold tag is 0 are skipped when decoding predictions.
    poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]

    kw_true.extend(TagConvert(tags, tokens))
    kw_pred.extend(TagConvert(pred_tags, tokens, poss))

label_f1 = f1_score(label_true, label_pred, average='macro')
# evaluate() takes (predictions, gold); the reverse order swaps P and R and
# applies the top-k cut to the gold keyphrases.
P, R, F1 = evaluate(kw_pred, kw_true)

In [None]:
print(P)
print(R)
print(F1)

In [None]:
##############################################

In [None]:
fs_num = 25  # 定义额外特征的数量

In [None]:
#####################

In [None]:
class Attention(nn.Module):
    """Single-vector attention that re-weights every position of a sequence."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.att_weight = nn.Parameter(torch.Tensor(hidden_dim, 1))
        nn.init.xavier_uniform_(self.att_weight)
        self.tanh = nn.Tanh()

    def attention_layer(self, h, mask):
        """Scale each position of ``h`` by its attention weight.

        Args:
            h: Hidden states, B*L*H.
            mask: B*L mask; positions equal to 0 receive zero attention.

        Returns:
            ``(reps, att_weight)`` with shapes B*L*H and B*L*1.
        """
        batch_size = mask.shape[0]
        projection = self.att_weight.expand(batch_size, -1, -1)  # B*H*1
        scores = torch.bmm(self.tanh(h), projection)  # B*L*1

        # Exclude masked ('PAD') positions from the softmax.
        blocked = mask.unsqueeze(dim=-1).eq(0)  # B*L*1
        scores = scores.masked_fill(blocked, float('-inf'))
        att_weight = F.softmax(scores, dim=1)  # normalise over the L axis

        reps = self.tanh(h * att_weight)  # B*L*H
        return reps, att_weight


class BertNerModelWithAttention(nn.Module):
    """BERT token classifier with an attention re-weighting step and extra features.

    Args:
        num_labels: Number of output classes per token.
        num_extra_features: How many leading dimensions of ``extra_features``
            are concatenated to each token representation. Defaults to 18,
            the value used originally.
    """

    def __init__(self, num_labels, num_extra_features=18):
        super(BertNerModelWithAttention, self).__init__()
        self.bert = BertModel.from_pretrained(weight)
        # Take the encoder width from the loaded config instead of assuming 768.
        hidden_size = self.bert.config.hidden_size
        self.num_extra_features = num_extra_features
        self.dropout = nn.Dropout(0.1)
        self.layernorm = nn.LayerNorm(normalized_shape=hidden_size)  # normalise over the BERT output width
        self.relu = nn.ReLU()
        self.linear_dropout = nn.Dropout(0.1)
        self.attention = Attention(hidden_size)
        self.classifier = nn.Linear(hidden_size + num_extra_features, num_labels)

    def forward(self, input_ids, attention_mask, extra_features, token_type_ids=None):
        """Return ``(logits, attention_weights)``.

        ``logits`` has one score vector per token; ``attention_weights`` is
        whatever ``self.attention.attention_layer`` returns as its second value.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]  # [batch_size, seq_len, hidden_dim]
        sequence_output = self.dropout(sequence_output)

        # LayerNorm -> ReLU -> Dropout on top of the encoder output.
        normalized_output = self.layernorm(sequence_output)
        activated_output = self.relu(normalized_output)
        dropout_output = self.linear_dropout(activated_output)

        context_vector, attention_weights = self.attention.attention_layer(dropout_output, attention_mask)  # [batch_size, seq_len, hidden_dim]

        # Keep only the leading extra-feature dimensions.
        et_features = extra_features[:, :, :self.num_extra_features]

        # Concatenate the attended representation with the extra features.
        combined_output = torch.cat((context_vector, et_features), dim=-1)  # [batch_size, seq_len, hidden_dim + num_extra_features]
        logits = self.classifier(combined_output)  # [batch_size, seq_len, num_labels]

        return logits, attention_weights


In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from sklearn.metrics import f1_score

# model = BertNerModelWithSoftAttention(num_labels=6)
# model = model.to(device)


model = BertNerModelWithAttention(num_labels=6).to(device)

# optim = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)
optim = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)  # try a lower learning rate

# One loss value per position; targets equal to 0 are ignored.
loss_fn = CrossEntropyLoss(reduction='none', ignore_index=0).to(device)

epochs, best_f1, num_labels = 5, 0.0, 6

In [None]:
# 训练模型
import torch.nn.functional as F

import os

epochs = 5
best_f1 = 0.0
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./pretrain_pt', exist_ok=True)
for epoch in tqdm(range(epochs)):
    # ---- training pass ----
    loss_value = 0.0
    model.train()
    label_true, label_pred = [], []
    for batch in train_dataloader:
        optim.zero_grad()
        input_ids, attention_masks, _, features, tags = batch
        pred_tags, _ = model(input_ids.to(device), attention_masks.to(device), features.to(device))

        # CrossEntropyLoss expects the class dimension at position 1.
        loss = loss_fn(pred_tags.permute(0, 2, 1), tags.to(device))
        loss = loss.mean()
        loss.backward()
        optim.step()

        pred_tags = F.softmax(pred_tags, dim=-1)
        pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        loss_value += loss.item()

    label_train_f1 = f1_score(label_true, label_pred, average='macro')

    # ---- evaluation pass ----
    # model.eval() already switches dropout off. Overwriting Dropout.p would be
    # permanent and would disable dropout in every following training epoch.
    model.eval()
    kw_true, kw_pred = [], []
    label_true, label_pred = [], []
    for batch in test_dataloader:
        input_ids, attention_masks, tokens, features, tags = batch
        with torch.no_grad():
            pred_tags, _ = model(input_ids.to(device), attention_masks.to(device), features.to(device))
            pred_tags = F.softmax(pred_tags, dim=-1)
            pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        # Positions whose gold tag is 0 are skipped when decoding predictions.
        poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]

        kw_true.extend(TagConvert(tags, tokens))
        kw_pred.extend(TagConvert(pred_tags, tokens, poss))

    label_f1 = f1_score(label_true, label_pred, average='macro')
    # evaluate() takes (predictions, gold); the reverse order swaps P and R and
    # applies the top-k cut to the gold keyphrases.
    P, R, F1 = evaluate(kw_pred, kw_true)

    if F1 > best_f1:
        best_f1 = F1
        torch.save(model.state_dict(), './pretrain_pt/bert_with_soft_ET.pt')

    print("epoch{}:  loss:{:.2f}   train_f1_value:{:.2f}  test_f1_value:{:.2f}  kw_f1_value:{:.2f}".format(
        epoch + 1, loss_value / len(train_dataloader), label_train_f1, label_f1, F1
    ))


In [None]:
model = BertNerModelWithAttention(num_labels=6)
# map_location lets a checkpoint saved from a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load('./pretrain_pt/bert_with_soft_ET.pt', map_location=device))
model = model.to(device)

In [None]:
# print(P)
# print(R)
# print(F1)

In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
import torch.nn as nn

# Load the best checkpoint; map_location keeps this working on CPU-only machines.
model.load_state_dict(torch.load('./pretrain_pt/bert_with_soft_ET.pt', map_location=device))
model.eval()


def inference_and_evaluate(test_dataloader, model, device):
    """Run ``model`` over ``test_dataloader`` and score tags and keyphrases.

    Args:
        test_dataloader: Iterable of
            ``(input_ids, attention_masks, tokens, features, tags)`` batches.
        model: Model returning logits, or a tuple whose first element is the logits.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        ``(label_f1, P, R, F1)``: macro F1 over tag ids, then keyphrase
        precision, recall and F1 as returned by ``evaluate``.
    """
    kw_true, kw_pred = [], []
    label_true, label_pred = [], []

    # eval() disables dropout; the Dropout modules' `p` is deliberately left
    # untouched so the model can still be trained afterwards.
    model.eval()
    for batch in test_dataloader:
        input_ids, attention_masks, tokens, features, tags = batch
        with torch.no_grad():
            outputs = model(input_ids.to(device), attention_masks.to(device), features.to(device))
            pred_tags = outputs[0] if isinstance(outputs, tuple) else outputs  # handle tuple output
            pred_tags = F.softmax(pred_tags, dim=-1)
            pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        # Positions whose gold tag is 0 are skipped when decoding predictions.
        poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]

        kw_true.extend(TagConvert(tags, tokens))
        kw_pred.extend(TagConvert(pred_tags, tokens, poss))

    label_f1 = f1_score(label_true, label_pred, average='macro')
    # evaluate() takes (predictions, gold); the reverse order swaps P and R and
    # applies the top-k cut to the gold keyphrases.
    P, R, F1 = evaluate(kw_pred, kw_true)

    return label_f1, P, R, F1

# 调用推理和评价函数
label_f1, P, R, F1 = inference_and_evaluate(test_dataloader, model, device)

print(f"Label F1 Score: {label_f1:.2f}")
print(f"Precision: {P:.2f}")
print(f"Recall: {R:.2f}")
print(f"F1 Score: {F1:.2f}")


### 定义词级别和句子级别的注意力层

层注意力 2.0

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from sklearn.metrics import f1_score
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = BertHANModel(num_labels=6)
# model = model.to(device)
#
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# loss_fn = nn.CrossEntropyLoss().to(device)


In [None]:
# import torch
# import torch.nn.functional as F
# from torch.cuda.amp import GradScaler, autocast
# from tqdm import tqdm
# from sklearn.metrics import f1_score
#
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = BertHANModel(num_labels=6)
# model = model.to(device)
#
# optim = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)
# loss_fn = nn.CrossEntropyLoss(reduction='none', ignore_index=0)
# loss_fn = loss_fn.to(device)
#
# scaler = GradScaler()


In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import f1_score
from transformers import BertModel
from torch import nn
from torch.cuda.amp import autocast, GradScaler

class Attention(nn.Module):
    """Attention pooling: collapse a sequence into one vector with a learned context vector."""

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.context = nn.Parameter(torch.FloatTensor(hidden_dim, 1))
        nn.init.xavier_uniform_(self.context)

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over its sequence axis.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            mask: Optional (batch, seq_len) tensor; positions equal to 0 are
                excluded from the softmax.

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        scores = torch.tanh(torch.matmul(x, self.context)).squeeze(-1)
        if mask is not None:
            # Multiplying by the mask would only set masked scores to 0, which
            # still earns them softmax weight. Fill with the most negative
            # finite value instead: masked positions get weight 0, and a fully
            # masked row yields uniform weights rather than NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=-1)
        weighted_sum = torch.bmm(attention_weights.unsqueeze(1), x).squeeze(1)
        return weighted_sum

class BertHANModel(nn.Module):
    """BERT encoder with attention pooling, a bidirectional GRU and extra per-token features.

    Args:
        num_labels: Number of output classes per token.
        hidden_dim: Width of the BERT hidden states.
        rnn_dim: Hidden size of each GRU direction.
        num_extra_features: How many leading dimensions of ``extra_features``
            are concatenated before classification. Defaults to 17, the
            eye-tracking slice used originally.
    """

    def __init__(self, num_labels, hidden_dim=768, rnn_dim=256, num_extra_features=17):
        super(BertHANModel, self).__init__()
        self.num_extra_features = num_extra_features
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.word_attention = Attention(hidden_dim)
        self.rnn = nn.GRU(hidden_dim, rnn_dim, batch_first=True, bidirectional=True)
        self.sentence_attention = Attention(rnn_dim * 2)
        self.classifier = nn.Linear(rnn_dim * 2 + num_extra_features, num_labels)

    def forward(self, input_ids, attention_mask, extra_features):
        """Return logits of shape (batch_size, seq_len, num_labels)."""
        bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs[0]  # shape: (batch_size, seq_length, hidden_dim)

        # Word-level attention. The attention mask is forwarded so the pooling
        # layer is told which positions are padding.
        word_attention_output = self.word_attention(sequence_output, attention_mask)

        # Sentence-level GRU over the pooled vector (a length-1 sequence).
        rnn_output, _ = self.rnn(word_attention_output.unsqueeze(1))

        # Sentence-level attention
        sentence_attention_output = self.sentence_attention(rnn_output)

        # Broadcast the sentence vector to every token position so it can be
        # concatenated with the per-token features.
        if sentence_attention_output.dim() == 2:
            sentence_attention_output = sentence_attention_output.unsqueeze(1).expand(-1, extra_features.size(1), -1)

        # Keep only the leading extra-feature dimensions.
        eye_tracking_features = extra_features[:, :, :self.num_extra_features]

        combined_output = torch.cat((sentence_attention_output, eye_tracking_features), dim=-1)  # [batch_size, seq_len, rnn_dim * 2 + num_extra_features]

        logits = self.classifier(combined_output)  # shape: (batch_size, seq_len, num_labels)

        return logits





def calculate_f1(y_pred, y_true):
    """Flatten both tensors and keep only positions whose target id is non-zero.

    Returns:
        ``(y_pred, y_true)`` as filtered 1-D numpy arrays; no score is computed.
    """
    flat_true = y_true.view(-1).detach().cpu().numpy()
    flat_pred = y_pred.view(-1).detach().cpu().numpy()
    labelled = np.flatnonzero(flat_true != 0)  # indices of positions with a real label
    return flat_pred[labelled], flat_true[labelled]

def TagConvert(raw_tags, words_set, poss=None):
    """Decode per-token tag ids into keyphrase strings, one list per sample.

    Tag id 1 starts a phrase, 2 continues it, 3 ends it (the phrase is
    emitted) and 4 is a single-word keyphrase; 0 and other ids add nothing.

    Args:
        raw_tags: 2-D array or tensor of tag ids indexed ``[sample][position]``.
        words_set: Tokens indexed ``[position][sample]``.
        poss: Optional per-sample collections of positions to skip.

    Returns:
        A list with one list of keyphrase strings per sample.
    """
    true_tags = []
    for i in range(raw_tags.shape[0]):
        row = raw_tags[i]
        kw_list = []
        phrase = ""
        for j in range(len(row)):
            item = row[j]
            if item == 0:
                continue
            # `is not None` (not `!= None`) so array-like `poss` is accepted too.
            if poss is not None and j in poss[i]:
                continue
            if item == 4:
                kw_list.append(str(words_set[j][i]))
            elif item == 1:
                phrase += str(words_set[j][i])
            elif item == 2:
                phrase += " " + str(words_set[j][i])
            elif item == 3:
                phrase += " " + str(words_set[j][i])
                kw_list.append(phrase)
                phrase = ""
        true_tags.append(kw_list)
    return true_tags

def evaluate(predict_data, target_data, topk=3):
    """Compute micro-averaged precision, recall and F1 over keyphrase sets.

    Args:
        predict_data: Per-sample predictions; each entry is a ranked sequence
            of keyphrases or a dict mapping keyphrase to score (ranked by
            descending score, ties broken by keyphrase).
        target_data: Per-sample gold keyphrases, aligned with ``predict_data``.
        topk: Only the first ``topk`` predictions of each sample are scored.

    Returns:
        ``(precision, recall, f1)`` as percentages rounded to two decimals.
    """
    true_count, pred_count, gold_count = 0.0, 0.0, 0.0
    for index, words in enumerate(predict_data):
        y_true = target_data[index]
        # The representation is checked per sample; the container type of
        # `predict_data` says nothing about an individual prediction.
        if isinstance(words, dict):
            ranked = sorted(words.items(), key=lambda item: (-item[1], item[0]))
            y_pred = [keyword for keyword, _ in ranked]
        else:
            y_pred = words
        y_pred = y_pred[0: topk]
        true_count += len(set(y_pred) & set(y_true))
        pred_count += len(y_pred)
        gold_count += len(y_true)

    # Guard every ratio against an empty denominator.
    p = (true_count / pred_count) if pred_count != 0 else 0
    r = (true_count / gold_count) if gold_count != 0 else 0
    f1 = ((2 * r * p) / (r + p)) if (r + p) != 0 else 0

    return round(p * 100, 2), round(r * 100, 2), round(f1 * 100, 2)

# 假设已经定义了数据集和数据加载器
# train_dataloader = ...
# test_dataloader = ...

# 训练和评估
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertHANModel(num_labels=6).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-2)
# One loss value per position; targets equal to 0 are ignored.
loss_fn = nn.CrossEntropyLoss(reduction='none', ignore_index=0).to(device)

epochs, best_f1 = 5, 0.0
scaler = GradScaler()

import os

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./pretrain_pt', exist_ok=True)
for epoch in range(epochs):
    # ---- training pass (mixed precision) ----
    loss_value = 0.0
    model.train()
    label_true, label_pred = [], []
    for batch in train_dataloader:
        optim.zero_grad()
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        features = batch[3].to(device)
        tags = batch[4].to(device)

        with autocast():
            pred_tags = model(input_ids, attention_masks, features)

            # Flatten logits and targets so their shapes match for the loss.
            pred_tags = pred_tags.reshape(-1, pred_tags.size(-1))
            tags = tags.reshape(-1)

            loss = loss_fn(pred_tags, tags)
            loss = loss.mean()

        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

        pred_tags = F.softmax(pred_tags, dim=-1)
        pred_tags = torch.argmax(pred_tags, dim=-1)
        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)
        loss_value += loss.item()

    label_train_f1 = f1_score(label_true, label_pred, average='macro')

    # ---- evaluation pass ----
    # model.eval() already switches dropout off. Overwriting Dropout.p would be
    # permanent and would disable dropout in every following training epoch.
    model.eval()
    kw_true, kw_pred = [], []
    label_true, label_pred = [], []
    for batch in test_dataloader:
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        tokens = batch[2]  # tokens are not tensors; use them as-is
        features = batch[3].to(device)
        tags = batch[4].to(device)

        with torch.no_grad():
            with autocast():
                pred_tags = model(input_ids, attention_masks, features)
                pred_tags = F.softmax(pred_tags, dim=-1)
                pred_tags = torch.argmax(pred_tags, dim=-1)

        y_pred, y_true = calculate_f1(pred_tags, tags)
        label_true.extend(y_true)
        label_pred.extend(y_pred)

        # Positions whose gold tag is 0 are skipped when decoding predictions.
        poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]

        kw_true.extend(TagConvert(tags, tokens))
        kw_pred.extend(TagConvert(pred_tags, tokens, poss))

    label_f1 = f1_score(label_true, label_pred, average='macro')
    # evaluate() takes (predictions, gold); the reverse order swaps P and R and
    # applies the top-k cut to the gold keyphrases.
    P, R, F1 = evaluate(kw_pred, kw_true)

    if F1 > best_f1:
        best_f1 = F1
        torch.save(model.state_dict(), './pretrain_pt/bert_HAtten_ET.pt')

    print("epoch{}:  loss:{:.2f}   train_f1_value:{:.2f}  test_f1_value:{:.2f}  kw_f1_value:{:.2f}".format(
        epoch + 1, loss_value / len(train_dataloader), label_train_f1, label_f1, F1
    ))

    torch.cuda.empty_cache()


In [None]:
sentence_ids = features[:, :, -1]
print("Sentence IDs Shape:", sentence_ids.shape)
print("Sentence IDs Example:", sentence_ids[0])


In [None]:
def load_model(model_path, num_labels):
    """Build a ``BertHANModel``, load weights from ``model_path`` and return it in eval mode.

    Args:
        model_path: Path of a state-dict checkpoint written by ``torch.save``.
        num_labels: Number of output classes the checkpoint was trained with.

    Returns:
        The model, moved to CUDA when available (otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BertHANModel(num_labels=num_labels)
    # map_location lets a checkpoint saved from a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    return model

def calculate_f1(y_pred, y_true):
    """Return ``(y_pred, y_true)`` flattened to numpy, without positions whose target is 0.

    Only filtering happens here; the metric itself is computed by the caller.
    """
    targets = y_true.view(-1)
    guesses = y_pred.view(-1)
    targets = targets.detach().cpu().numpy()
    guesses = guesses.detach().cpu().numpy()
    (kept,) = np.nonzero(targets != 0)  # positions that carry a real label
    return guesses[kept], targets[kept]

def TagConvert(raw_tags, words_set, poss=None):
    """Decode per-token tag ids into keyphrase strings, one list per sample.

    Tag id 1 starts a phrase, 2 continues it, 3 ends it (the phrase is
    emitted) and 4 is a single-word keyphrase; 0 and other ids add nothing.

    Args:
        raw_tags: 2-D array or tensor of tag ids indexed ``[sample][position]``.
        words_set: Tokens indexed ``[position][sample]``.
        poss: Optional per-sample collections of positions to skip.

    Returns:
        A list with one list of keyphrase strings per sample.
    """
    true_tags = []
    for i in range(raw_tags.shape[0]):
        row = raw_tags[i]
        kw_list = []
        phrase = ""
        for j in range(len(row)):
            item = row[j]
            if item == 0:
                continue
            # `is not None` (not `!= None`) so array-like `poss` is accepted too.
            if poss is not None and j in poss[i]:
                continue
            if item == 4:
                kw_list.append(str(words_set[j][i]))
            elif item == 1:
                phrase += str(words_set[j][i])
            elif item == 2:
                phrase += " " + str(words_set[j][i])
            elif item == 3:
                phrase += " " + str(words_set[j][i])
                kw_list.append(phrase)
                phrase = ""
        true_tags.append(kw_list)
    return true_tags

def evaluate(predict_data, target_data, topk=3):
    """Compute micro-averaged precision, recall and F1 over keyphrase sets.

    Args:
        predict_data: Per-sample predictions; each entry is a ranked sequence
            of keyphrases or a dict mapping keyphrase to score (ranked by
            descending score, ties broken by keyphrase).
        target_data: Per-sample gold keyphrases, aligned with ``predict_data``.
        topk: Only the first ``topk`` predictions of each sample are scored.

    Returns:
        ``(precision, recall, f1)`` as percentages rounded to two decimals.
    """
    true_count, pred_count, gold_count = 0.0, 0.0, 0.0
    for index, words in enumerate(predict_data):
        y_true = target_data[index]
        # The representation is checked per sample; the container type of
        # `predict_data` says nothing about an individual prediction.
        if isinstance(words, dict):
            ranked = sorted(words.items(), key=lambda item: (-item[1], item[0]))
            y_pred = [keyword for keyword, _ in ranked]
        else:
            y_pred = words
        y_pred = y_pred[0: topk]
        true_count += len(set(y_pred) & set(y_true))
        pred_count += len(y_pred)
        gold_count += len(y_true)

    # Guard every ratio against an empty denominator.
    p = (true_count / pred_count) if pred_count != 0 else 0
    r = (true_count / gold_count) if gold_count != 0 else 0
    f1 = ((2 * r * p) / (r + p)) if (r + p) != 0 else 0

    return round(p * 100, 2), round(r * 100, 2), round(f1 * 100, 2)

def predict_and_evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` and score tags and keyphrases.

    Args:
        model: Model returning per-token logits for
            ``(input_ids, attention_masks, features)``.
        dataloader: Iterable of batches indexed as
            ``(input_ids, attention_masks, tokens, features, tags)``.

    Returns:
        ``(label_f1, P, R, F1)``: macro F1 over tag ids, then keyphrase
        precision, recall and F1 as returned by ``evaluate``.
    """
    # Send inputs to wherever the model actually lives; guessing from CUDA
    # availability fails for a model that was left on the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    label_true, label_pred = [], []
    kw_true, kw_pred = [], []

    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch[0].to(device)
            attention_masks = batch[1].to(device)
            tokens = batch[2]  # tokens are not tensors; use them as-is
            features = batch[3].to(device)
            tags = batch[4].to(device)

            pred_tags = model(input_ids, attention_masks, features)
            pred_tags = torch.argmax(pred_tags, dim=-1)

            y_pred, y_true = calculate_f1(pred_tags, tags)
            label_true.extend(y_true)
            label_pred.extend(y_pred)

            # Positions whose gold tag is 0 are skipped when decoding predictions.
            poss = [[j for j, tag in enumerate(row.tolist()) if tag == 0] for row in tags]
            kw_true.extend(TagConvert(tags, tokens))
            kw_pred.extend(TagConvert(pred_tags, tokens, poss))

    label_f1 = f1_score(label_true, label_pred, average='macro')
    # evaluate() takes (predictions, gold); the reverse order swaps P and R and
    # applies the top-k cut to the gold keyphrases.
    P, R, F1 = evaluate(kw_pred, kw_true)
    return label_f1, P, R, F1

# 加载模型
model_path = './pretrain_pt/bert_HAtten_ET.pt'
num_labels = 6
model = load_model(model_path, num_labels)

# 假设 test_dataloader 已经定义好
label_f1, P, R, F1 = predict_and_evaluate(model, test_dataloader)

print(f"label_f1: {label_f1:.2f}, Precision: {P:.2f}, Recall: {R:.2f}, F1: {F1:.2f}")