In [83]:
# Some setup code for imports
import os
import sys
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from custom_datasets.BayesRiskDatasetLoader import BayesRiskDatasetLoader

In [84]:

# First we will load the test set with the calculated scores.
dataset_loader = BayesRiskDatasetLoader("validation_predictive", n_hypotheses=100, n_references=1000, sampling_method='ancestral')
dataset = dataset_loader.load(type="pandas")

In [85]:
### Next up we will read the trained model and calculate the score given by the heads
from models.pl_predictive.PLPredictiveModelFactory import PLPredictiveModelFactory
path = "C:/Users/gerso/FBR/predictive/tatoeba-de-en/models/mixture_model_2/"
model, factory = PLPredictiveModelFactory.load(path)

C:\Users\gerso\FBR\NMT/tatoeba-de-en/model
1.001726887106992e-06
using a mixture model


In [86]:
import numpy as np
import torch

import torch.distributions as td

from models.MBR_model.BaseMBRModel import BaseMBRModel
from utils.translation_model_utils import batch_sample, batch


class GaussianMixtureMBRModel(BaseMBRModel):
    '''
    Model wraps the NMT model
    '''

    def __init__(self, predictive_model, device="cuda", sample_size=1000):
        super().__init__(predictive_model)
        self.device_name = device
        self.predictive_model = predictive_model.to(self.device_name)
        self.sample_size = sample_size

    def forward(self, source, n_samples_per_source=256, batch_size=16, ):
        hypotheses = batch_sample(self.predictive_model.nmt_model, self.predictive_model.tokenizer, [source],
                                  n_samples=n_samples_per_source, )

        return self.get_best(source, hypotheses)

    def get_scores(self, sources, samples, batch_size=16):

        model_out = self.get_model_out(sources, samples, batch_size=batch_size)

        scores = self.model_out_to_risk(model_out)

        return scores

    def get_model_out(self, sources, samples, batch_size=16):
        result = {}
        for x, y in zip(batch(sources, n=batch_size), batch(samples, n=batch_size)):
            model_out = self.predictive_model.predict(x, y)
            result = self.add_model_out_to_result(result, model_out)
        result = {k: torch.tensor(v) for k, v in result.items()}

        return result

    def get_best(self, source, hypotheses, batch_size=16):
        sources = [source] * len(hypotheses)

        scores = self.get_scores(sources, hypotheses, batch_size=batch_size)

        best_index = np.argmax(scores)

        best = hypotheses[best_index]

        return best

    def model_out_to_risk(self, model_out):

        mixture = self.get_mixture(model_out["loc"], model_out["scale"], model_out["logits"])

        sample_scores = mixture.sample()

        mean = sample_scores.mean(-1)
        return mean

    def get_mean(self, sources, samples, batch_size=16):
        model_out = self.get_model_out(sources, samples, batch_size=batch_size)
        mixture = self.get_mixture(model_out["loc"], model_out["scale"], model_out["logits"])
        
        sample_scores = mixture.sample((100000,))
        print(sample_scores.shape)
        mean = sample_scores.mean(0)

        return mean


    def add_model_out_to_result(self, result, model_out):
        if result == {}:
            result["loc"] = []
            result["scale"] = []
            result["logits"] = []
        result['loc'] += list(model_out[0].cpu().numpy())
        result['scale'] += list(model_out[1].cpu().numpy())
        result['logits'] += list(model_out[2].cpu().numpy())

        return result

    def get_mixture(self, loc, scale, logits):
        components = self.make_components(loc, scale)
        mixture = td.MixtureSameFamily(td.Categorical(logits=logits), components)
        
        return mixture

    def make_components(self, loc, scale):
        sample_size = self.sample_size
        shape = loc.shape
        loc = loc.unsqueeze(-1)
        scale = scale.unsqueeze(-1)
        return td.Independent(td.Normal(loc=loc, scale=scale), 1)

In [87]:
wrapped_model = GaussianMixtureMBRModel(model)

In [90]:
### Here we get the predicted loc and scale for each sentence:
from tqdm import tqdm
all_model_out = []
for i, row in tqdm(dataset.data.iterrows(), total=dataset.data.shape[0]):
    source = row["source"]
    hypotheses = list(row["hypotheses"])[:2]

    model_out = wrapped_model.get_mean([source]* len(hypotheses), hypotheses)
    all_model_out.append(model_out)
    if i > 10:
        break

  0%|▍                                                                                                                                                                                                     | 5/2500 [00:00<00:51, 48.08it/s]

torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])


  0%|▊                                                                                                                                                                                                    | 11/2500 [00:00<00:57, 42.93it/s]

torch.Size([10000, 2, 1])
torch.Size([10000, 2, 1])





In [None]:
# Save for later

## Next up we check if we are optimistic or pessimistic.

def calc_opt_statistics(predicted_scores, reference_scores):
    opt_count = 0
    pes_count = 0
    opt_sum = 0
    pes_sum = 0
    for pred_score, ref_score in zip(predicted_scores, reference_scores):
        larger = pred_score >= ref_score
        smaller = ref_score > pred_score
        opt_count += sum(larger)
        pes_count += sum(smaller)
        opt_sum += sum((pred_score - ref_score) * larger)
        pes_sum += sum((ref_score - pred_score) * smaller)
    total = opt_count + pes_count
    opt_percentage = opt_count/total
    pes_percentage = pes_count/total
    opt_avg = opt_sum/opt_count
    pes_avg = pes_sum/pes_count
        
    return {"opt_count": opt_count, "pes_count": pes_count, 
            "opt_sum": opt_sum, "pes_sum": pes_sum,
            "opt_percentage": opt_percentage,
            "pes_percentage": pes_percentage,
            "opt_avg": opt_avg,
            "pes_avg": pes_avg,
           }
