In [None]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from dataset import EurDataset, collate_data
from models.transceiver import DeepSC
from torch.utils.data import DataLoader
from utils import BleuScore, SNR_to_noise, greedy_decode, SeqtoText
from tqdm import tqdm
from sklearn.preprocessing import normalize
from w3lib.html import remove_tags

# ---------------------------------------------------------------------------
# Default experiment settings
# ---------------------------------------------------------------------------

# Dataset and vocabulary locations.
data_dir = 'data/europarl/train_data.pkl'
vocab_file = 'data/europarl/vocab.json'

# Channel name and the checkpoint directory that goes with it.
# To evaluate the AWGN checkpoints instead, switch both values together:
# checkpoint_path = 'checkpoints/deepsc-AWGN'
# channel = 'AWGN'
channel = 'Rayleigh'
checkpoint_path = 'checkpoints/deepsc-Rayleigh'

# Sentence length bounds (MAX_LENGTH is forwarded to the greedy decoder).
MIN_LENGTH = 4
MAX_LENGTH = 30

# Model sizes handed to the DeepSC constructor.
num_layers = 3
num_heads = 8
d_model = 128
dff = 256

# Evaluation loop settings: mini-batch size and number of repeated passes.
batch_size = 64
epochs = 2

# BERT resources (not referenced by the code visible in this notebook section).
bert_config_path = 'bert/cased_L-12_H-768_A-12/bert_config.json'
bert_checkpoint_path = 'bert/cased_L-12_H-768_A-12/bert_model.ckpt'
bert_dict_path = 'bert/cased_L-12_H-768_A-12/vocab.txt'

# Run on the first GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Function to calculate performance
def performance(SNR, net, vocab_file, batch_size, epochs, MAX_LENGTH, channel):
    """Evaluate ``net`` on the test split and return mean BLEU scores per SNR.

    For every SNR value the whole test set is greedily decoded and the decoded
    sentences are scored against the input sentences with four ``BleuScore``
    instances (weights ``(1,0,0,0)`` .. ``(0,0,0,1)``). The whole sweep is
    repeated ``epochs`` times and the repetitions are averaged.

    This function reads the module-level names ``token_to_idx``, ``end_idx``,
    ``pad_idx``, ``start_idx`` and ``device``; they must be defined before the
    function is called.

    Args:
        SNR: iterable of SNR values; each is converted with ``SNR_to_noise``.
        net: model to evaluate; it is switched to eval mode.
        vocab_file: unused; kept so existing callers keep working.
        batch_size: batch size for the test ``DataLoader``.
        epochs: number of times the full SNR sweep is repeated.
        MAX_LENGTH: maximum length forwarded to ``greedy_decode``.
        channel: channel name forwarded to ``greedy_decode``.

    Returns:
        A list with four entries (one per ``BleuScore`` instance, in the order
        above). Each entry is an array holding, for every SNR value, the
        sentence-level score averaged over the test set and over ``epochs``.
    """
    bleu_scores = [BleuScore(1, 0, 0, 0), BleuScore(0, 1, 0, 0),
                   BleuScore(0, 0, 1, 0), BleuScore(0, 0, 0, 1)]

    test_eur = EurDataset('test')
    # Pinned host memory only helps host-to-GPU copies, so skip it on CPU.
    test_iterator = DataLoader(test_eur, batch_size=batch_size, num_workers=0,
                               pin_memory=(device.type == 'cuda'),
                               collate_fn=collate_data)

    StoT = SeqtoText(token_to_idx, end_idx)
    score = [[] for _ in bleu_scores]
    net.eval()
    with torch.no_grad():
        for _ in range(epochs):
            decoded_per_snr = []   # decoded sentences, one list per SNR
            target_per_snr = []    # input sentences, one list per SNR

            for snr in tqdm(SNR):
                decoded_sentences = []
                target_sentences = []
                noise_std = SNR_to_noise(snr)

                for sents in test_iterator:
                    sents = sents.to(device)

                    out = greedy_decode(net, sents, noise_std, MAX_LENGTH, pad_idx,
                                        start_idx, channel)

                    # extend() appends in place; rebuilding the list with `+`
                    # on every batch would be quadratic in the dataset size.
                    decoded_sentences.extend(
                        map(StoT.sequence_to_text, out.cpu().numpy().tolist()))
                    target_sentences.extend(
                        map(StoT.sequence_to_text, sents.cpu().numpy().tolist()))

                decoded_per_snr.append(decoded_sentences)
                target_per_snr.append(target_sentences)

            # NOTE(review): decoded sentences are passed first and the input
            # sentences second; confirm this matches the argument order that
            # BleuScore.compute_blue_score expects.
            for i, bleu_score in enumerate(bleu_scores):
                bleu = []
                for decoded, target in zip(decoded_per_snr, target_per_snr):
                    bleu.append(bleu_score.compute_blue_score(decoded, target))
                # Rows are SNR values, columns are sentences: average per SNR.
                bleu = np.mean(np.array(bleu), axis=1)
                score[i].append(bleu)

    # Average the repeated sweeps so each entry has one value per SNR.
    score = [np.mean(np.array(s), axis=0) for s in score]

    return score

# Define SNR values
SNR = [0, 3, 6, 9, 12, 15, 18]

# Load the vocabulary; the context manager closes the file handle promptly.
with open(vocab_file, 'rb') as vocab_fp:
    vocab = json.load(vocab_fp)
token_to_idx = vocab['token_to_idx']
# Reverse lookup: index -> token.
idx_to_token = {index: token for token, index in token_to_idx.items()}
num_vocab = len(token_to_idx)
pad_idx = token_to_idx["<PAD>"]
start_idx = token_to_idx["<START>"]
end_idx = token_to_idx["<END>"]

# Initialize and load model
# NOTE(review): the trailing 0.1 is presumably the dropout rate; confirm
# against the DeepSC constructor.
deepsc = DeepSC(num_layers, num_vocab, num_vocab,
                num_vocab, num_vocab, d_model, num_heads,
                dff, 0.1).to(device)

# Collect every '.pth' file together with the integer that follows the last
# underscore in its name, so the highest-numbered checkpoint can be picked.
model_paths = []
for fn in os.listdir(checkpoint_path):
    if not fn.endswith('.pth'):
        continue
    idx = int(os.path.splitext(fn)[0].split('_')[-1])
    model_paths.append((os.path.join(checkpoint_path, fn), idx))

# Fail with a clear message instead of an IndexError on model_paths[-1].
if not model_paths:
    raise FileNotFoundError(
        f"No '.pth' checkpoint found in {checkpoint_path!r}")

model_paths.sort(key=lambda x: x[1])
model_path, _ = model_paths[-1]
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
deepsc.load_state_dict(checkpoint)
print('Model loaded!')

# Calculate BLEU scores
bleu_scores = performance(SNR, deepsc, vocab_file, batch_size, epochs, MAX_LENGTH, channel)
for i, bleu in enumerate(bleu_scores):
    print(f'{i+1}-gram BLEU score:', bleu)




# Experiments below

In [None]:
# Plot BLEU scores for different n-grams
def plot_bleu_scores(SNR, bleu_scores):
    """Print and plot one BLEU-vs-SNR curve per n-gram order on a 2x2 grid.

    Args:
        SNR: x-axis values (in dB) shared by every curve.
        bleu_scores: iterable of score sequences; entry ``k`` is labelled as
            the ``(k+1)``-gram score, so at most four entries are supported.
    """
    orders = (1, 2, 3, 4)
    fig = plt.figure(figsize=(20, 12))

    for position, curve in enumerate(bleu_scores):
        order = orders[position]
        print(f'{order}-gram BLEU score:', curve)

        # One panel per n-gram order, filled row by row.
        ax = fig.add_subplot(2, 2, position + 1)
        ax.plot(SNR, curve, marker='o')
        ax.set_xlabel('SNR (dB)')
        ax.set_ylabel('BLEU Score')
        ax.set_title(f'{order}-gram BLEU Score vs SNR')
        ax.grid(True)

    fig.tight_layout()
    plt.show()

# Draw the BLEU-vs-SNR panels for the scores computed by performance()
plot_bleu_scores(SNR, bleu_scores)