# Homework and bake-off: Word similarity

In [1]:
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2020"

## Contents

1. [Overview](#Overview)
1. [Set-up](#Set-up)
1. [Dataset readers](#Dataset-readers)
1. [Dataset comparisons](#Dataset-comparisons)
  1. [Vocab overlap](#Vocab-overlap)
  1. [Pair overlap and score correlations](#Pair-overlap-and-score-correlations)
1. [Evaluation](#Evaluation)
  1. [Dataset evaluation](#Dataset-evaluation)
  1. [Dataset error analysis](#Dataset-error-analysis)
  1. [Full evaluation](#Full-evaluation)
1. [Homework questions](#Homework-questions)
  1. [PPMI as a baseline [0.5 points]](#PPMI-as-a-baseline-[0.5-points])
  1. [Gigaword with LSA at different dimensions [0.5 points]](#Gigaword-with-LSA-at-different-dimensions-[0.5-points])
  1. [Gigaword with GloVe for a small number of iterations [0.5 points]](#Gigaword-with-GloVe-for-a-small-number-of-iterations-[0.5-points])
  1. [Dice coefficient [0.5 points]](#Dice-coefficient-[0.5-points])
  1. [t-test reweighting [2 points]](#t-test-reweighting-[2-points])
  1. [Enriching a VSM with subword information [2 points]](#Enriching-a-VSM-with-subword-information-[2-points])
  1. [Your original system [3 points]](#Your-original-system-[3-points])
1. [Bake-off [1 point]](#Bake-off-[1-point])

## Overview

Word similarity datasets have long been used to evaluate distributed representations. This notebook provides basic code for conducting such analyses with a number of datasets:

| Dataset | Pairs | Task-type | Current best Spearman $\rho$ | Best $\rho$ paper |
|---------|-------|-----------|------------------------------|-------------------|
| [WordSim-353](http://www.cs.technion.ac.il/~gabr/resources/data/wordsim353/) | 353 | Relatedness | 82.8 | [Speer et al. 2017](https://arxiv.org/abs/1612.03975) |
| [MTurk-771](http://www2.mta.ac.il/~gideon/mturk771.html) | 771 | Relatedness | 81.0 | [Speer et al. 2017](https://arxiv.org/abs/1612.03975) |
| [The MEN Test Collection](http://clic.cimec.unitn.it/~elia.bruni/MEN) | 3,000 | Relatedness | 86.6 | [Speer et al. 2017](https://arxiv.org/abs/1612.03975)  | 
| [SimVerb-3500-dev](http://people.ds.cam.ac.uk/dsg40/simverb.html) | 500 | Similarity | 61.1 | [Mrkšić et al. 2016](https://arxiv.org/pdf/1603.00892.pdf) |
| [SimVerb-3500-test](http://people.ds.cam.ac.uk/dsg40/simverb.html) | 3,000 | Similarity | 62.4 | [Mrkšić et al. 2016](https://arxiv.org/pdf/1603.00892.pdf) |

Each of the similarity datasets contains word pairs with an associated human-annotated similarity score. (We convert these to distances to align intuitively with our distance measure functions.) The evaluation code measures the distance between the word pairs in your chosen VSM (which should be a `pd.DataFrame`).

The evaluation metric for each dataset is the [Spearman correlation coefficient $\rho$](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient) between the annotated scores and your distances, as is standard in the literature. We also macro-average these correlations across the datasets for an overall summary. (In using the macro-average, we are saying that we care about all the datasets equally, even though they vary in size.)
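
To make the metric concrete, here is a minimal pure-Python sketch of Spearman's $\rho$ as the Pearson correlation of rank-transformed values. The helper names `average_ranks` and `spearman_rho` and the toy numbers are illustrative assumptions; the notebook itself relies on `scipy.stats.spearmanr`.

```python
def average_ranks(values):
    """Rank values from 1 to n, giving tied values their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman_rho(xs, ys):
    """Pearson correlation computed on the ranks of xs and ys."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


# Made-up human similarity ratings and model distances for four word pairs.
similarities = [9.5, 7.0, 3.2, 1.0]
distances = [0.10, 0.35, 0.30, 0.90]

# Negating the ratings makes them distance-like, so that agreement
# between annotators and model shows up as a positive correlation.
negated = [-s for s in similarities]
rho_negated = spearman_rho(negated, distances)
rho_raw = spearman_rho(similarities, distances)
print(rho_negated, rho_raw)
```

Because only ranks matter, any monotone rescaling of the distances would leave both values unchanged; negation simply flips the sign.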

This homework ([questions at the bottom of this notebook](#Homework-questions)) asks you to write code that uses the count matrices in `data/vsmdata` to create and evaluate some baseline models as well as an original model $M$ that you design. This accounts for 9 of the 10 points for this assignment.

For the associated bake-off, we will distribute two new word similarity or relatedness datasets and associated reader code, and you will evaluate $M$ (no additional training or tuning allowed!) on those new datasets. Systems that enter will receive the additional homework point, and systems that achieve the top score will receive an additional 0.5 points.

## Set-up

In [2]:
from collections import defaultdict
import csv
import itertools
import numpy as np
import os
import pandas as pd
from scipy.stats import spearmanr
import vsm
from IPython.display import display

In [3]:
VSM_HOME = os.path.join('data', 'vsmdata')

WORDSIM_HOME = os.path.join('data', 'wordsim')

## Dataset readers

In [4]:
def wordsim_dataset_reader(
        src_filename, 
        header=False, 
        delimiter=',', 
        score_col_index=2):
    """Basic reader that works for all similarity datasets. They are 
    all tabular-style releases where the first two columns give the 
    words and a later column (`score_col_index`) gives the score.

    Parameters
    ----------
    src_filename : str
        Full path to the source file.
    header : bool
        Whether `src_filename` has a header. Default: False
    delimiter : str
        Field delimiter in `src_filename`. Default: ','
    score_col_index : int
        Column containing the similarity scores. Default: 2

    Yields
    ------
    (str, str, float)
       (w1, w2, score) where `score` is the negative of the similarity
       score in the file so that we are intuitively aligned with our
       distance-based code. To align with our VSMs, all the words are 
       downcased.

    """
    with open(src_filename) as f:
        reader = csv.reader(f, delimiter=delimiter)
        if header:
            next(reader)
        for row in reader:
            w1 = row[0].strip().lower()
            w2 = row[1].strip().lower()
            score = row[score_col_index]
            # Negative of scores to align intuitively with distance functions:
            score = -float(score)
            yield (w1, w2, score)

def wordsim353_reader():
    """WordSim-353: http://www.cs.technion.ac.il/~gabr/resources/data/wordsim353/"""
    src_filename = os.path.join(
        WORDSIM_HOME, 'wordsim353', 'combined.csv')
    return wordsim_dataset_reader(
        src_filename, header=True)

def mturk771_reader():
    """MTURK-771: http://www2.mta.ac.il/~gideon/mturk771.html"""
    src_filename = os.path.join(
        WORDSIM_HOME, 'MTURK-771.csv')
    return wordsim_dataset_reader(
        src_filename, header=False)

def simverb3500dev_reader():
    """SimVerb-3500: http://people.ds.cam.ac.uk/dsg40/simverb.html"""
    src_filename = os.path.join(
        WORDSIM_HOME, 'SimVerb-3500', 'SimVerb-500-dev.txt')
    return wordsim_dataset_reader(
        src_filename, delimiter="\t", header=False, score_col_index=3)

def simverb3500test_reader():
    """SimVerb-3500: http://people.ds.cam.ac.uk/dsg40/simverb.html"""
    src_filename = os.path.join(
        WORDSIM_HOME, 'SimVerb-3500', 'SimVerb-3000-test.txt')
    return wordsim_dataset_reader(
        src_filename, delimiter="\t", header=False, score_col_index=3)

def men_reader():
    """MEN: http://clic.cimec.unitn.it/~elia.bruni/MEN"""
    src_filename = os.path.join(
        WORDSIM_HOME, 'MEN', 'MEN_dataset_natural_form_full')
    return wordsim_dataset_reader(
        src_filename, header=False, delimiter=' ') 
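
The reader pattern above can be tried without the data files. The following is a small sketch that applies the same logic (skip an optional header, lowercase the words, negate the score) to an in-memory file; `read_pairs` and the two toy rows are hypothetical and only mirror the structure of `wordsim_dataset_reader`.

```python
import csv
import io


def read_pairs(lines, header=False, delimiter=",", score_col_index=2):
    """Yield (w1, w2, -score) from any iterable of delimited text lines."""
    reader = csv.reader(lines, delimiter=delimiter)
    if header:
        next(reader)
    for row in reader:
        w1 = row[0].strip().lower()
        w2 = row[1].strip().lower()
        # Negate so that larger values mean "further apart".
        yield w1, w2, -float(row[score_col_index])


# Made-up rows in a comma-delimited layout with a header line.
toy_file = io.StringIO(
    "Word 1,Word 2,Human (mean)\n"
    "Tiger,Cat,7.35\n"
    "book,paper,7.46\n")

pairs = list(read_pairs(toy_file, header=True))
print(pairs)
```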

This collection of readers will be useful for flexible evaluations:

In [5]:
READERS = (wordsim353_reader, mturk771_reader, simverb3500dev_reader, 
           simverb3500test_reader, men_reader)

## Dataset comparisons

This section does some basic analysis of the datasets. The goal is to obtain a deeper understanding of what problem we're solving – what strengths and weaknesses the datasets have and how they relate to each other. For a full-fledged project, we would want to continue work like this and report on it in the paper, to provide context for the results.

In [6]:
def get_reader_name(reader):
    """Return a cleaned-up name for the similarity dataset 
    iterator `reader`
    """
    return reader.__name__.replace("_reader", "")

### Vocab overlap

How many vocabulary items are shared across the datasets?

In [7]:
def get_reader_vocab(reader):
    """Return the set of words (str) in `reader`."""
    vocab = set()
    for w1, w2, _ in reader():
        vocab.add(w1)
        vocab.add(w2)
    return vocab

In [8]:
def get_reader_vocab_overlap(readers=READERS):
    """Get data on the vocab-level relationships between pairs of 
    readers. Returns a `pd.DataFrame` containing this information.
    """
    data = []
    for r1, r2 in itertools.product(readers, repeat=2):       
        v1 = get_reader_vocab(r1)
        v2 = get_reader_vocab(r2)
        d = {
            'd1': get_reader_name(r1),
            'd2': get_reader_name(r2),
            'overlap': len(v1 & v2), 
            'union': len(v1 | v2),
            'd1_size': len(v1),
            'd2_size': len(v2)}
        data.append(d)
    return pd.DataFrame(data)

In [9]:
vocab_overlap = get_reader_vocab_overlap()

In [10]:
def vocab_overlap_crosstab(vocab_overlap):
    """Return an intuitively formatted `pd.DataFrame` giving 
    vocab-overlap counts for all the datasets represented in 
    `vocab_overlap`, the output of `get_reader_vocab_overlap`.
    """        
    xtab = pd.crosstab(
        vocab_overlap['d1'], 
        vocab_overlap['d2'], 
        values=vocab_overlap['overlap'], 
        aggfunc=np.mean)
    # Blank out the upper right to reduce visual clutter:
    for i in range(0, xtab.shape[0]):
        for j in range(i+1, xtab.shape[1]):
            xtab.iloc[i, j] = ''        
    return xtab        

In [11]:
vocab_overlap_crosstab(vocab_overlap)

| d1 \ d2 | men | mturk771 | simverb3500dev | simverb3500test | wordsim353 |
|---|---|---|---|---|---|
| men | 751 | | | | |
| mturk771 | 230 | 1113 | | | |
| simverb3500dev | 23 | 67 | 536 | | |
| simverb3500test | 30 | 94 | 532 | 823 | |
| wordsim353 | 86 | 158 | 13 | 17 | 437 |


This looks reasonable. By design, the SimVerb dev and test sets have a lot of overlap. The other overlap numbers are pretty small, even adjusting for dataset size.
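
One way to make "adjusting for dataset size" concrete is to divide each overlap by the size of the union of the two vocabularies (the Jaccard index). The sketch below does this for a few pairs, with the counts copied from the crosstab output above; the helper name `jaccard` is an assumption for illustration.

```python
# Vocabulary sizes (the diagonal) and selected overlap counts,
# copied from the crosstab output.
vocab_size = {
    "men": 751,
    "mturk771": 1113,
    "simverb3500dev": 536,
    "simverb3500test": 823,
    "wordsim353": 437}

overlap = {
    ("men", "mturk771"): 230,
    ("men", "wordsim353"): 86,
    ("mturk771", "wordsim353"): 158,
    ("simverb3500dev", "simverb3500test"): 532}


def jaccard(d1, d2):
    """Shared vocabulary as a fraction of the combined vocabulary."""
    shared = overlap[(d1, d2)]
    union = vocab_size[d1] + vocab_size[d2] - shared
    return shared / union


for d1, d2 in overlap:
    print("{:<16} {:<16} {:.3f}".format(d1, d2, jaccard(d1, d2)))
```

On this normalized scale, the two SimVerb splits stand far apart from every other pair listed here.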

### Pair overlap and score correlations

How many word pairs are shared across datasets and, for shared pairs, what is the correlation between their scores? That is, do the datasets agree?

In [12]:
def get_reader_pairs(reader):
    """Return the set of alphabetically-sorted word (str) tuples 
    in `reader`, as the keys of a dict mapping each tuple to its score.
    """
    return {tuple(sorted([w1, w2])): score for w1, w2, score in reader()}

In [13]:
def get_reader_pair_overlap(readers=READERS):
    """Return a `pd.DataFrame` giving the number of overlapping 
    word-pairs in pairs of readers, along with the Spearman 
    correlations.
    """    
    data = []
    for r1, r2 in itertools.product(readers, repeat=2):
        if r1.__name__ != r2.__name__:
            d1 = get_reader_pairs(r1)
            d2 = get_reader_pairs(r2)
            overlap = []
            for p, s in d1.items():
                if p in d2:
                    overlap.append([s, d2[p]])
            if overlap:
                s1, s2 = zip(*overlap)
                rho = spearmanr(s1, s2)[0]
            else:
                rho = None
            # Canonical order for the pair:
            n1, n2 = sorted([get_reader_name(r1), get_reader_name(r2)])
            d = {
                'd1': n1,
                'd2': n2,
                'pair_overlap': len(overlap),
                'rho': rho}
            data.append(d)
    df = pd.DataFrame(data)
    df = df.sort_values(['pair_overlap','d1','d2'], ascending=False)
    # Return only every other row to avoid repeats:
    return df[::2].reset_index(drop=True)

In [14]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    display(get_reader_pair_overlap())

|   | d1 | d2 | pair_overlap | rho |
|---|---|---|---|---|
| 0 | men | mturk771 | 11 | 0.592191 |
| 1 | men | wordsim353 | 5 | 0.7 |
| 2 | mturk771 | simverb3500test | 4 | 0.4 |
| 3 | men | simverb3500test | 2 | 1.0 |
| 4 | simverb3500dev | simverb3500test | 1 | |
| 5 | simverb3500test | wordsim353 | 0 | |
| 6 | simverb3500dev | wordsim353 | 0 | |
| 7 | mturk771 | wordsim353 | 0 | |
| 8 | mturk771 | simverb3500dev | 0 | |
| 9 | men | simverb3500dev | 0 | |


This looks reasonable: none of the datasets have a lot of overlapping pairs, so we don't have to worry too much about places where they give conflicting scores.
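
The pair comparison depends on putting each pair into a canonical order, so that the same two words match regardless of how a dataset happens to list them. A tiny sketch with made-up pairs and scores (`canonical_pair` is a hypothetical helper name):

```python
def canonical_pair(w1, w2):
    """Order-independent key for a word pair."""
    return tuple(sorted([w1, w2]))


# Made-up scores; the two datasets list the shared pair in opposite orders.
dataset_a = {
    canonical_pair("car", "automobile"): -9.0,
    canonical_pair("sun", "moon"): -6.0}
dataset_b = {
    canonical_pair("automobile", "car"): -45.0,
    canonical_pair("river", "water"): -49.0}

shared = sorted(set(dataset_a) & set(dataset_b))
print(shared)
```

Note that keying a dict on the pair keeps only the last score seen if a dataset repeats a pair.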

## Evaluation

This section builds up the evaluation code that you'll use for the homework and bake-off. For illustrations, I'll read in a VSM created from `data/vsmdata/giga_window5-scaled.csv.gz`:

In [15]:
giga5 = pd.read_csv(
    os.path.join(VSM_HOME, "giga_window5-scaled.csv.gz"), index_col=0)

### Dataset evaluation

In [16]:
def word_similarity_evaluation(reader, df, distfunc=vsm.cosine):
    """Word-similarity evaluation framework.
    
    Parameters
    ----------
    reader : iterator
        A reader for a word-similarity dataset. Just has to yield
        tuples (word1, word2, score).    
    df : pd.DataFrame
        The VSM being evaluated.        
    distfunc : function mapping vector pairs to floats.
        The measure of distance between vectors. Can also be 
        `vsm.euclidean`, `vsm.matching`, `vsm.jaccard`, as well as 
        any other float-valued function on pairs of vectors.    
        
    Raises
    ------
    ValueError
        If a word in `reader` is not in `df.index`.
    
    Returns
    -------
    float, data
        `float` is the Spearman rank correlation coefficient between 
        the dataset scores and the distance values obtained from 
        `df` using  `distfunc`. This evaluation is sensitive only to 
        rankings, not to absolute values.  `data` is a `pd.DataFrame` 
        with columns['word1', 'word2', 'score', 'distance'].
        
    """
    data = []
    for w1, w2, score in reader():
        d = {'word1': w1, 'word2': w2, 'score': score}
        for w in [w1, w2]:
            if w not in df.index:
                raise ValueError(
                    "Word '{}' is in the similarity dataset {} but not in the "
                    "DataFrame, making this evaluation ill-defined. Please "
                    "switch to a DataFrame with an appropriate vocabulary.".
                    format(w, get_reader_name(reader))) 
        d['distance'] = distfunc(df.loc[w1], df.loc[w2])
        data.append(d)
    data = pd.DataFrame(data)
    rho, pvalue = spearmanr(data['score'].values, data['distance'].values)
    return rho, data

In [17]:
rho, eval_df = word_similarity_evaluation(men_reader, giga5)

In [18]:
rho

0.40375964105441753

In [19]:
eval_df.head()

|   | word1 | word2 | score | distance |
|---|---|---|---|---|
| 0 | sun | sunlight | -50.0 | 0.956828 |
| 1 | automobile | car | -50.0 | 0.979143 |
| 2 | river | water | -49.0 | 0.970105 |
| 3 | stairs | staircase | -49.0 | 0.980475 |
| 4 | morning | sunrise | -49.0 | 0.963624 |


### Dataset error analysis

For error analysis, we can look at the word pairs with the largest delta between the gold score and the distance value in our VSM. We do these comparisons based on ranks, just as with our primary metric (Spearman $\rho$), and we normalize both rankings so that they have a comparable number of levels.
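
As a rough illustration of why the normalization matters: gold scores contain many ties and so produce few dense-rank levels, whereas model distances are almost all distinct. Dividing each ranking by its total puts the two on a comparable scale. The pure-Python sketch below uses made-up values, and `normalized_dense_ranks` is a hypothetical stand-in for the pandas-based helper in the next cell.

```python
def normalized_dense_ranks(values):
    """Dense-rank `values` (ties share a rank, with no gaps), then
    divide every rank by the sum of all ranks."""
    levels = {v: r for r, v in enumerate(sorted(set(values)), start=1)}
    ranks = [levels[v] for v in values]
    total = sum(ranks)
    return [r / total for r in ranks]


# Made-up gold scores (with a tie) and model distances for four pairs.
scores = [-50.0, -50.0, -49.0, -10.0]
distances = [0.20, 0.90, 0.30, 0.95]

score_ranks = normalized_dense_ranks(scores)        # dense ranks 1, 1, 2, 3
distance_ranks = normalized_dense_ranks(distances)  # dense ranks 1, 3, 2, 4

errors = [abs(d - s) for d, s in zip(distance_ranks, score_ranks)]
worst = max(range(len(errors)), key=errors.__getitem__)
print(worst, errors)
```

Here the second pair has the largest error: annotators rate it as maximally similar, while the model places it among the most distant.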

In [20]:
def word_similarity_error_analysis(eval_df):    
    eval_df['distance_rank'] = _normalized_ranking(eval_df['distance'])
    eval_df['score_rank'] = _normalized_ranking(eval_df['score'])
    eval_df['error'] =  abs(eval_df['distance_rank'] - eval_df['score_rank'])
    return eval_df.sort_values('error')
    
    
def _normalized_ranking(series):
    ranks = series.rank(method='dense')
    return ranks / ranks.sum()    

Best predictions:

In [21]:
word_similarity_error_analysis(eval_df).head()

|   | word1 | word2 | score | distance | distance_rank | score_rank | error |
|---|---|---|---|---|---|---|---|
| 1041 | hummingbird | pelican | -32.0 | 0.975007 | 0.000243 | 0.000244 | 2.434543e-07 |
| 2315 | lily | pigs | -13.0 | 0.980834 | 0.000488 | 0.000487 | 4.016842e-07 |
| 2951 | bucket | girls | -4.0 | 0.983473 | 0.000602 | 0.000603 | 4.151568e-07 |
| 150 | night | sunset | -43.0 | 0.96869 | 0.000102 | 0.000103 | 6.520315e-07 |
| 2062 | oak | petals | -17.0 | 0.979721 | 0.000435 | 0.000436 | 7.162632e-07 |


Worst predictions:

In [22]:
word_similarity_error_analysis(eval_df).tail()

|   | word1 | word2 | score | distance | distance_rank | score_rank | error |
|---|---|---|---|---|---|---|---|
| 67 | branch | twigs | -45.0 | 0.984622 | 0.00063 | 7.7e-05 | 0.000553 |
| 190 | birds | stork | -43.0 | 0.987704 | 0.000657 | 0.000103 | 0.000554 |
| 185 | bloom | tulip | -43.0 | 0.990993 | 0.000663 | 0.000103 | 0.000561 |
| 167 | bloom | blossom | -43.0 | 0.99176 | 0.000664 | 0.000103 | 0.000561 |
| 198 | bloom | rose | -43.0 | 0.992406 | 0.000664 | 0.000103 | 0.000561 |


### Full evaluation

A full evaluation is just a loop over all the readers on which one wants to evaluate, with a macro-average at the end:

In [23]:
def full_word_similarity_evaluation(df, readers=READERS, distfunc=vsm.cosine):
    """Evaluate a VSM against all datasets in `readers`.
    
    Parameters
    ----------
    df : pd.DataFrame
    readers : tuple 
        The similarity dataset readers on which to evaluate.
    distfunc : function mapping vector pairs to floats.
        The measure of distance between vectors. Can also be 
        `vsm.euclidean`, `vsm.matching`, `vsm.jaccard`, as well as 
        any other float-valued function on pairs of vectors.    
    
    Returns
    -------
    pd.Series
        Mapping dataset names to Spearman r values.
        
    """        
    scores = {}     
    for reader in readers:
        score, data_df = word_similarity_evaluation(reader, df, distfunc=distfunc)
        scores[get_reader_name(reader)] = score
    series = pd.Series(scores, name='Spearman r')
    series['Macro-average'] = series.mean()
    return series

In [24]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    display(full_word_similarity_evaluation(giga5))

wordsim353         0.327831
mturk771           0.143146
simverb3500dev    -0.065020
simverb3500test   -0.066314
men                0.403760
Macro-average      0.148681
Name: Spearman r, dtype: float64
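
As a quick check on the output above, the macro-average is just the unweighted mean of the five per-dataset correlations. The sketch below recomputes it and, for contrast, a size-weighted average that uses the pair counts from the Overview table. The weighted variant is not part of the notebook's evaluation and is shown only to illustrate what treating the datasets equally means.

```python
# Per-dataset Spearman values copied from the output above.
rho = {
    "wordsim353": 0.327831,
    "mturk771": 0.143146,
    "simverb3500dev": -0.065020,
    "simverb3500test": -0.066314,
    "men": 0.403760}

# Pair counts as listed in the Overview table.
n_pairs = {
    "wordsim353": 353,
    "mturk771": 771,
    "simverb3500dev": 500,
    "simverb3500test": 3000,
    "men": 3000}

# Every dataset counts equally, regardless of its size.
macro_average = sum(rho.values()) / len(rho)

# Alternative (an assumption for comparison): weight by number of pairs.
size_weighted = sum(rho[d] * n_pairs[d] for d in rho) / sum(n_pairs.values())

print(round(macro_average, 6), round(size_weighted, 6))
```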

## Homework questions

Please embed your homework responses in this notebook, and do not delete any cells from the notebook. (You are free to add as many cells as you like as part of your responses.)

### PPMI as a baseline [0.5 points]

The insight behind PPMI is a recurring theme in word representation learning, so it is a natural baseline for our task. For this question, write a function called `run_giga_ppmi_baseline` that does the following:

1. Reads the Gigaword count matrix with a window of 20 and a flat scaling function into a `pd.DataFrame`, as is done in the VSM notebooks. The file is `data/vsmdata/giga_window20-flat.csv.gz`, and the VSM notebooks provide examples of the needed code.

1. Reweights this count matrix with PPMI.

1. Evaluates this reweighted matrix using `full_word_similarity_evaluation`. The return value of `run_giga_ppmi_baseline` should be the return value of this call to `full_word_similarity_evaluation`.

The goal of this question is to help you get more familiar with the code in `vsm` and the function `full_word_similarity_evaluation`.
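
For intuition about what the reweighting step does, here is a minimal pure-Python sketch of PPMI on a toy count matrix. It assumes the usual definition (the log of observed over expected counts, with zero counts and negative values mapped to 0); the function name `ppmi` is hypothetical, and `vsm.pmi` may differ in implementation details.

```python
import math


def ppmi(counts):
    """Positive PMI for a list-of-lists count matrix."""
    row_sums = [sum(row) for row in counts]
    col_sums = [sum(col) for col in zip(*counts)]
    total = sum(row_sums)
    result = []
    for i, row in enumerate(counts):
        new_row = []
        for j, x in enumerate(row):
            if x == 0:
                # Convention assumed here: log(0) is treated as 0.
                new_row.append(0.0)
            else:
                # Observed count over the count expected under independence.
                pmi = math.log(x * total / (row_sums[i] * col_sums[j]))
                new_row.append(max(pmi, 0.0))
        result.append(new_row)
    return result


# Toy counts (the same values appear in the test matrices later on).
X = [[4, 4, 2, 0],
     [4, 61, 8, 18],
     [2, 8, 10, 0],
     [0, 18, 0, 5]]

for row in ppmi(X):
    print([round(v, 3) for v in row])
```

Cells whose counts are lower than independence would predict, such as row 0, column 1, end up as 0.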

The function `test_run_giga_ppmi_baseline` can be used to test that you've implemented this specification correctly.

In [25]:
def run_giga_ppmi_baseline():
    
    ##### YOUR CODE HERE
    # Gigaword counts: window of 20, flat scaling.
    giga20 = pd.read_csv(
        os.path.join(VSM_HOME, "giga_window20-flat.csv.gz"), index_col=0)
    # Reweight the counts with positive PMI.
    giga20_ppmi = vsm.pmi(giga20, positive=True)
    return full_word_similarity_evaluation(giga20_ppmi)




In [26]:
def test_run_giga_ppmi_baseline(run_giga_ppmi_baseline):
    result = run_giga_ppmi_baseline()
    ws_result = result.loc['wordsim353'].round(2)
    ws_expected = 0.58
    assert ws_result == ws_expected, \
        "Expected wordsim353 value of {}; got {}".format(ws_expected, ws_result)

In [27]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_run_giga_ppmi_baseline(run_giga_ppmi_baseline)

### Gigaword with LSA at different dimensions [0.5 points]

We might expect PPMI and LSA to form a solid pipeline that combines the strengths of PPMI with those of dimensionality reduction. However, LSA has a hyper-parameter $k$ – the dimensionality of the final representations – that will impact performance. For this problem, write a wrapper function `run_ppmi_lsa_pipeline` that does the following:

1. Takes as input a count `pd.DataFrame` and an LSA parameter `k`.
1. Reweights the count matrix with PPMI.
1. Applies LSA with dimensionality `k`.
1. Evaluates this reweighted matrix using `full_word_similarity_evaluation`. The return value of `run_ppmi_lsa_pipeline` should be the return value of this call to `full_word_similarity_evaluation`.

The goal of this question is to help you get a feel for how much LSA alone can contribute to this problem. 
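
For reference, LSA is commonly implemented as a truncated singular value decomposition. The NumPy sketch below keeps the first `k` left singular vectors scaled by their singular values; `lsa_sketch` is a hypothetical name, and whether `vsm.lsa` scales its output in exactly this way is an assumption.

```python
import numpy as np


def lsa_sketch(X, k):
    """Rank-k row representations from a truncated SVD of X."""
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    # Scale each retained left singular vector by its singular value.
    return U[:, :k] * s[:k]


# Toy matrix standing in for a (reweighted) count matrix.
X = np.array([
    [4., 4., 2., 0.],
    [4., 61., 8., 18.],
    [2., 8., 10., 0.],
    [0., 18., 0., 5.]])

reduced = lsa_sketch(X, k=2)
print(reduced.shape)
```

With `k` equal to the full rank, the rows keep all of their pairwise dot products from `X`; smaller values of `k` trade that fidelity for denser, lower-dimensional vectors.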

The  function `test_run_ppmi_lsa_pipeline` will test your function on the count matrix in `data/vsmdata/giga_window20-flat.csv.gz`.

In [28]:
def run_ppmi_lsa_pipeline(count_df, k):
    
    ##### YOUR CODE HERE
    count_df_ppmi = vsm.pmi(count_df, positive=True)
    count_df_lsa = vsm.lsa(count_df_ppmi, k)
    result = full_word_similarity_evaluation(count_df_lsa)
    return result
     




In [29]:
def test_run_ppmi_lsa_pipeline(run_ppmi_lsa_pipeline):
    giga20 = pd.read_csv(
        os.path.join(VSM_HOME, "giga_window20-flat.csv.gz"), index_col=0)
    results = run_ppmi_lsa_pipeline(giga20, k=10)
    men_expected = 0.57
    men_result = results.loc['men'].round(2)
    assert men_result == men_expected,\
        "Expected men value of {}; got {}".format(men_expected, men_result)

In [30]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_run_ppmi_lsa_pipeline(run_ppmi_lsa_pipeline)

### Gigaword with GloVe for a small number of iterations [0.5 points]

Ideally, we would run GloVe for a very large number of iterations on a GPU machine to compare it against its close cousin PMI. However, we don't want this homework to cost you a lot of money or monopolize a lot of your available computing resources, so let's instead just probe GloVe a little bit to see if it has promise for our task. For this problem, write a function `run_small_glove_evals` that does the following:

1. Reads in `data/vsmdata/giga_window20-flat.csv.gz`.
1. Runs GloVe for 10, 100, and 200 iterations on `data/vsmdata/giga_window20-flat.csv.gz`, using the `mittens` implementation of `GloVe`. 
  * For all the other parameters to `mittens.GloVe` besides `max_iter`, use the package's defaults.
  * Because of the way that implementation is designed, these will have to be separate runs, but they should be relatively quick. 
1. Stores the values in a `dict` mapping each `max_iter` value to its associated 'Macro-average' score according to `full_word_similarity_evaluation`. `run_small_glove_evals`  should return this `dict`.

The trend should give you a sense for whether it is worth running GloVe for more iterations.

Some implementation notes:

* Your trained GloVe matrix `X` needs to be wrapped in a `pd.DataFrame` to work with `full_word_similarity_evaluation`. `pd.DataFrame(X, index=giga20.index)` will do the trick.

* If `glv` is your GloVe model, then running `glv.sess.close()` after each model is trained will silence warnings from TensorFlow about interactive sessions being active.

Performance will vary a lot for this function, so there is some uncertainty in the testing, but `test_run_small_glove_evals` will at least check that you wrote a function with the right general logic.

In [31]:
def run_small_glove_evals():

    from mittens import GloVe
    
    ##### YOUR CODE HERE
    giga20 = pd.read_csv(
        os.path.join(VSM_HOME, "giga_window20-flat.csv.gz"), index_col=0)
    result = {}
    for max_iter in (10, 100, 200):
        # Each setting of `max_iter` requires a separate training run.
        glv = GloVe(max_iter=max_iter)
        X = glv.fit(giga20.to_numpy())
        df = pd.DataFrame(X, index=giga20.index)
        evaluation = full_word_similarity_evaluation(df)
        result[max_iter] = evaluation['Macro-average']
    return result



In [32]:
def test_run_small_glove_evals(run_small_glove_evals):
    data = run_small_glove_evals()
    for max_iter in (10, 100, 200):
        assert max_iter in data
        assert isinstance(data[max_iter], float)

In [33]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_run_small_glove_evals(run_small_glove_evals)


Iteration 200: loss: 1968702.255

### Dice coefficient [0.5 points]

Implement the Dice coefficient for real-valued vectors, as

$$
\textbf{dice}(u, v) = 
1 - \frac{
  2 \sum_{i=1}^{n}\min(u_{i}, v_{i})
}{
    \sum_{i=1}^{n} (u_{i} + v_{i})
}$$
 
You can use `test_dice_implementation` below to check that your implementation is correct.
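
As a worked instance of the formula, take $u = (4, 4, 2, 0)$ and $v = (4, 61, 8, 18)$, the first two rows of the matrix used in `test_dice_implementation`. The element-wise minima are $4, 4, 2, 0$, and the two vectors sum to $10$ and $91$ respectively, so

$$
\textbf{dice}(u, v) =
1 - \frac{2\,(4 + 4 + 2 + 0)}{10 + 91}
= 1 - \frac{20}{101}
\approx 0.80198
$$

which agrees with the first assertion in the test.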

In [34]:
def test_dice_implementation(func):
    """`func` should be an implementation of `dice` as defined above."""
    X = np.array([
        [  4.,   4.,   2.,   0.],
        [  4.,  61.,   8.,  18.],
        [  2.,   8.,  10.,   0.],
        [  0.,  18.,   0.,   5.]]) 
    assert func(X[0], X[1]).round(5) == 0.80198
    assert func(X[1], X[2]).round(5) == 0.67568

In [35]:
def dice(u, v):

    
    ##### YOUR CODE HERE
    # Twice the sum of element-wise minima, over the total mass of both vectors.
    return 1.0 - (2.0 * np.minimum(u, v).sum()) / (np.sum(u) + np.sum(v))
    



In [36]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_dice_implementation(dice)

### t-test reweighting [2 points]



The t-test statistic can be thought of as a reweighting scheme. For a count matrix $X$, row index $i$, and column index $j$:

$$\textbf{ttest}(X, i, j) = 
\frac{
    P(X, i, j) - \big(P(X, i, *)P(X, *, j)\big)
}{
\sqrt{(P(X, i, *)P(X, *, j))}
}$$

where $P(X, i, j)$ is $X_{ij}$ divided by the sum of all the values in $X$, $P(X, i, *)$ is the sum of the values in row $i$ of $X$ divided by the sum of all the values in $X$, and $P(X, *, j)$ is the sum of the values in column $j$ of $X$ divided by the sum of all the values in $X$.
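
To see the definition at work on single cells, the following pure-Python sketch evaluates it directly for the toy matrix that also appears in `test_ttest_implementation`. The helper `ttest_cell` is hypothetical and deliberately unvectorized; it is meant for checking the arithmetic, not as an efficient implementation.

```python
X = [[4, 4, 2, 0],
     [4, 61, 8, 18],
     [2, 8, 10, 0],
     [0, 18, 0, 5]]


def ttest_cell(X, i, j):
    """t-test weight for cell (i, j) of the count matrix X."""
    total = sum(sum(row) for row in X)
    p_ij = X[i][j] / total                    # P(X, i, j)
    p_i = sum(X[i]) / total                   # P(X, i, *)
    p_j = sum(row[j] for row in X) / total    # P(X, *, j)
    expected = p_i * p_j
    return (p_ij - expected) / expected ** 0.5


# Cell (0, 0): total = 144, P(X, 0, 0) = 4/144, both marginals = 10/144,
# so the value simplifies to 4/10 - 10/144.
print(round(ttest_cell(X, 0, 0), 5))
```

A zero count still receives a nonzero (negative) weight, because the numerator compares it against the expected probability.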

For this problem, implement this reweighting scheme. You can use `test_ttest_implementation` below to check that your implementation is correct. You do not need to use this for any evaluations, though we hope you will be curious enough to do so!

In [37]:
def test_ttest_implementation(func):
    """`func` should be an implementation of t-test reweighting as 
    defined above.
    """
    X = pd.DataFrame(np.array([
        [  4.,   4.,   2.,   0.],
        [  4.,  61.,   8.,  18.],
        [  2.,   8.,  10.,   0.],
        [  0.,  18.,   0.,   5.]]))    
    actual = np.array([
        [ 0.33056, -0.07689,  0.04321, -0.10532],
        [-0.07689,  0.03839, -0.10874,  0.07574],
        [ 0.04321, -0.10874,  0.36111, -0.14894],
        [-0.10532,  0.07574, -0.14894,  0.05767]])    
    predicted = func(X)
    assert np.array_equal(predicted.round(5), actual)

In [38]:
def ttest(df):
    
    ##### YOUR CODE HERE
    # Joint probabilities P(X, i, j):
    probs = df / df.to_numpy().sum()
    # Products of the marginals, P(X, i, *) * P(X, *, j):
    expected = np.outer(probs.sum(axis=1), probs.sum(axis=0))
    return (probs - expected) / np.sqrt(expected)


In [39]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_ttest_implementation(ttest)



### Enriching a VSM with subword information [2 points]

It might be useful to combine character-level information with word-level information. To help you begin assessing this idea, this question asks you to write a function that modifies an existing VSM so that the representation for each word $w$ is the element-wise sum of $w$'s original word-level representation with all the representations for the n-grams $w$ contains. 

The following starter code should help you structure this and clarify the requirements, and a simple test is included below as well.

You don't need to write a lot of code; the motivation for this question is that the function you write could have practical value.
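
To make the intended computation concrete, here is a self-contained sketch that does not use `vsm`. It assumes that n-grams are extracted after padding each word with the boundary symbols `<w>` and `</w>`, and that an n-gram's vector is the sum of the vectors of the words containing it; the expected values in `test_subword_enrichment` are consistent with both assumptions. The names `char_ngrams` and `enrich` are hypothetical.

```python
from collections import defaultdict


def char_ngrams(word, n):
    """Character n-grams of `word`, with boundary symbols when n > 1."""
    chars = ["<w>"] + list(word) + ["</w>"] if n > 1 else list(word)
    return ["".join(chars[i: i + n]) for i in range(len(chars) - n + 1)]


def enrich(vectors, n):
    """Add to each word vector the vectors of all n-grams it contains."""
    dim = len(next(iter(vectors.values())))
    # An n-gram's vector is the sum of the vectors of the words it occurs in.
    ngram_vecs = defaultdict(lambda: [0] * dim)
    for word, vec in vectors.items():
        for gram in char_ngrams(word, n):
            ngram_vecs[gram] = [a + b for a, b in zip(ngram_vecs[gram], vec)]
    enriched = {}
    for word, vec in vectors.items():
        total = list(vec)
        for gram in char_ngrams(word, n):
            total = [a + b for a, b in zip(total, ngram_vecs[gram])]
        enriched[word] = total
    return enriched


# Same toy data as in `test_subword_enrichment`.
vectors = {
    "ABCD": [1, 1, 2, 1],
    "BCDA": [3, 4, 2, 4],
    "CDAB": [0, 0, 1, 0],
    "DABC": [1, 0, 0, 0]}

enriched = enrich(vectors, n=2)
print(char_ngrams("ABCD", 2))
print(enriched["ABCD"])
```

Words that share many n-grams, like the four rotations in this toy vocabulary, are pulled toward each other, which is the intended effect of the enrichment.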

In [32]:
def subword_enrichment(df, n=4):
    
    # 1. Use `vsm.ngram_vsm` to create a character-level 
    # VSM from `df`, using the above parameter `n` to 
    # set the size of the ngrams.
    
    ##### YOUR CODE HERE
    ngram_vsm = vsm.ngram_vsm(df, n)

        
    # 2. Use `vsm.character_level_rep` to get the representation
    # for every word in `df` according to the character-level
    # VSM you created above.
    
    ##### YOUR CODE HERE
    # `vsm.character_level_rep` takes the word itself (the index label),
    # not its vector. Reusing `df.columns` keeps the addition at step 3
    # aligned column by column even when the columns are labeled.
    char_reps = np.array(
        [vsm.character_level_rep(w, ngram_vsm, n) for w in df.index])
    df_char = pd.DataFrame(char_reps, index=df.index, columns=df.columns)

    # 3. For each representation created at step 2, add in its
    # original representation from `df`. (This should use
    # element-wise addition; the dimensionality of the vectors
    # will be unchanged.)
                            
    ##### YOUR CODE HERE

    df_agg = df_char + df
    
    # 4. Return a `pd.DataFrame` with the same index and column
    # values as `df`, but filled with the new representations
    # created at step 3.
                            
    ##### YOUR CODE HERE
    return df_agg



In [33]:
def test_subword_enrichment(func):
    """`func` should be an implementation of subword_enrichment as 
    defined above.
    """
    vocab = ["ABCD", "BCDA", "CDAB", "DABC"]
    df = pd.DataFrame([
        [1, 1, 2, 1],
        [3, 4, 2, 4],
        [0, 0, 1, 0],
        [1, 0, 0, 0]], index=vocab)
    expected = pd.DataFrame([
        [14, 14, 18, 14],
        [22, 26, 18, 26],
        [10, 10, 14, 10],
        [14, 10, 10, 10]], index=vocab)
    new_df = func(df, n=2)
    assert np.array_equal(expected.columns, new_df.columns), \
        "Columns are not the same"
    assert np.array_equal(expected.index, new_df.index), \
        "Indices are not the same"
    assert np.array_equal(expected.values, new_df.values), \
        "Co-occurrence values aren't the same"    

In [34]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
    test_subword_enrichment(subword_enrichment)

### Your original system [3 points]

This question asks you to design your own model. You can of course include steps made above (ideally, the above questions informed your system design!), but your model should not be literally identical to any of the above models. Other ideas: retrofitting, autoencoders, GloVe, subword modeling, ... 

Requirements:

1. Your code must operate on one of the count matrices in `data/vsmdata`. You can choose which one. __Other pretrained vectors cannot be introduced__.

1. Your code must be self-contained, so that we can work with your model directly in your homework submission notebook. If your model depends on external data or other resources, please submit a ZIP archive containing these resources along with your submission.

In the cell below, please provide a brief technical description of your original system, so that the teaching team can gain an understanding of what it does. This will help us to understand your code and analyze all the submissions to identify patterns and strategies.

In [35]:
# Enter your system description in this cell.
# Please do not remove this comment.
if 'IS_GRADESCOPE_ENV' not in os.environ:
    pass

    """
    The system builds on the pipeline shown in class,
    i.e. ppmi->lsa->ae,
    but it uses an MMD-VAE, which tries to constrain
    the latent representation to a unit Gaussian.
    I believe the advantage is that it should be robust to
    different distance measures.
    I also tried using ppmi->lsa->subword enrichment; it gave
    poor results, e.g.:
                    finance    0.000000
                    balance    0.004174
                    chance     0.006348
                    advance    0.006587
                    glance     0.006653
    However, a linear combination of the ppmi->lsa model and the enriched model looked promising, so I did:
    enriched_vsm = (w1*dataset_ppmi_lsa_k + w2*enriched_model)/(w1+w2)


    """
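
The description above refers to an MMD-VAE. The `torch_autoencoder_mmd_info` module is not shown in this notebook, so the following is only a dependency-free sketch of the general idea behind a maximum mean discrepancy (MMD) penalty, under stated assumptions: a Gaussian kernel with bandwidth `sigma=1.0`, the simple biased estimator, and toy 2-dimensional "latent codes". All function names here are hypothetical. The penalty is near zero when the codes look like draws from a unit Gaussian and grows as they drift away from it.

```python
import math
import random


def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel between two equal-length vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2 * sigma ** 2))


def mean_kernel(xs, ys, sigma=1.0):
    """Average kernel value over all pairs drawn from `xs` and `ys`."""
    total = sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(xs, ys, sigma=1.0):
    """Biased estimate of squared MMD: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    return (mean_kernel(xs, xs, sigma)
            + mean_kernel(ys, ys, sigma)
            - 2 * mean_kernel(xs, ys, sigma))


def sample_gaussian(n, dim, mean=0.0):
    """Draw `n` points from an isotropic Gaussian with unit variance."""
    return [[random.gauss(mean, 1.0) for _ in range(dim)] for _ in range(n)]


random.seed(0)
prior = sample_gaussian(40, 2)              # samples from the unit Gaussian prior
matched = sample_gaussian(40, 2)            # toy latent codes that match the prior
shifted = sample_gaussian(40, 2, mean=3.0)  # toy latent codes far from the prior

print("matched:", mmd_squared(matched, prior))
print("shifted:", mmd_squared(shifted, prior))
```

In an autoencoder of this kind, a term like `mmd_squared(latent_codes, prior_samples)` would typically be added to the reconstruction loss, which is consistent with the separate `recons` and `mmd_err` values printed in the training log.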

## Bake-off [1 point]

For the bake-off, we will release two additional datasets. The announcement will go out on the discussion forum. We will also release reader code for these datasets that you can paste into this notebook. You will evaluate your custom model $M$ (from the previous question) on these new datasets using `full_word_similarity_evaluation`. Rules:

1. Only one evaluation is permitted.
1. No additional system tuning is permitted once the bake-off has started.

The cells below this one constitute your bake-off entry.

People who enter will receive the additional homework point, and people whose systems achieve the top score will receive an additional 0.5 points. We will test the top-performing systems ourselves, and only systems for which we can reproduce the reported results will win the extra 0.5 points.

Late entries will be accepted, but they cannot earn the extra 0.5 points. Similarly, you cannot win the bake-off unless your homework is submitted on time.

The announcement will include the details on where to submit your entry.

In [36]:
# Enter your bake-off assessment code into this cell. 
# Please do not remove this comment.

if 'IS_GRADESCOPE_ENV' not in os.environ:
    pass
    # Please enter your code in the scope of the above conditional.
    ##### YOUR CODE HERE
    
    dataset_filename = "giga_window20-flat.csv.gz"
    k_dims = 100
    test_word = "finance"
    use_subword_enrichment = True

    # Load the raw count matrix and inspect the neighbours of `test_word`.
    dataset = pd.read_csv(
        os.path.join(VSM_HOME, dataset_filename), index_col=0)
    print(len(dataset))
    neighbours_raw = vsm.neighbors(test_word, dataset).head()
    print("neighbours_raw")
    print(neighbours_raw)

    # Reweight the counts with Positive PMI.
    dataset_ppmi = vsm.pmi(dataset, positive=True)
    neighbours_ppmi = vsm.neighbors(test_word, dataset_ppmi).head()
    print("neighbours_ppmi")
    print(neighbours_ppmi)

    # Reduce the PPMI matrix to `k_dims` dimensions with LSA.
    dataset_ppmi_lsa_k = vsm.lsa(dataset_ppmi, k=k_dims)
    neighbours_ppmi_lsa_k = vsm.neighbors(test_word, dataset_ppmi_lsa_k).head()
    print("neighbours_ppmi_lsa_k")
    print(neighbours_ppmi_lsa_k)

    # Optionally blend the LSA vectors with their subword-enriched version,
    # using a weighted average with weights `w1` and `w2`.
    enriched_neighbours_ppmi_lsa_k = dataset_ppmi_lsa_k
    if use_subword_enrichment:
        enriched_model = subword_enrichment(dataset_ppmi_lsa_k, n=4)
        w1 = 1
        w2 = 0.01
        blended = (
            w1 * dataset_ppmi_lsa_k.to_numpy()
            + w2 * enriched_model.to_numpy()) / (w1 + w2)
        enriched_neighbours_ppmi_lsa_k = pd.DataFrame(
            blended, index=dataset.index)
        neighbours_enriched = vsm.neighbors(
            test_word, enriched_neighbours_ppmi_lsa_k).head()
        print("neighbours_enriched")
        print(neighbours_enriched)


5000
neighbours_raw
finance    0.000000
.          0.132519
</p>       0.139872
<p>        0.141332
said       0.158305
dtype: float64
neighbours_ppmi
finance     0.000000
banking     0.363871
monetary    0.404030
loans       0.413165
banks       0.415815
dtype: float64
neighbours_ppmi_lsa_k
finance     0.000000
banking     0.136505
monetary    0.167089
debt        0.171278
loans       0.172590
dtype: float64
neighbours_enriched
finance        0.000000
banking        0.081684
investment     0.087688
investments    0.117994
management     0.121275
dtype: float64
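
One detail of the weighted average used above is worth checking. If the neighbour ranking is based on cosine distance (assumed here to be the measure in use), dividing by `w1 + w2` rescales every row by the same positive constant and so cannot change any distance. Only the ratio of `w1` to `w2`, together with the relative magnitudes of the two matrices, matters. The sketch below illustrates this with hypothetical toy vectors; `cosine_distance` and `combine` are illustrative helpers, not functions from the course code.

```python
import math


def cosine_distance(u, v):
    """1 minus the cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def combine(base, enriched, w1, w2, normalize=True):
    """Weighted sum of two vectors, optionally divided by w1 + w2."""
    scale = (w1 + w2) if normalize else 1.0
    return [(w1 * b + w2 * e) / scale for b, e in zip(base, enriched)]


# Hypothetical toy vectors; the enriched rows are deliberately much larger,
# as can happen when they are sums over many word vectors.
base = {"finance": [0.9, 0.1, 0.3], "banking": [0.8, 0.2, 0.4]}
enriched = {"finance": [12.0, 30.0, 7.0], "banking": [10.0, 5.0, 9.0]}
w1, w2 = 1.0, 0.01

with_norm = {w: combine(base[w], enriched[w], w1, w2, True) for w in base}
without_norm = {w: combine(base[w], enriched[w], w1, w2, False) for w in base}

d_norm = cosine_distance(with_norm["finance"], with_norm["banking"])
d_raw = cosine_distance(without_norm["finance"], without_norm["banking"])
print(d_norm, d_raw)  # the two distances agree up to floating-point error
```

Under a distance that is not scale-invariant, such as Euclidean distance, the normalization would change the absolute distances, although a uniform rescaling would still preserve their ordering.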


In [38]:
from torch_autoencoder_mmd_info import TorchAutoencoderMMDIG

# Train the MMD autoencoder on the (optionally enriched) PPMI+LSA vectors.
model = TorchAutoencoderMMDIG()
df_train = pd.DataFrame(enriched_neighbours_ppmi_lsa_k, index=dataset.index)
enriched_neighbours_ppmi_lsa_k_ae = model.fit(df_train)

# Save the representations returned by `fit` and inspect the neighbours
# of `test_word` in the new space.
enriched_neighbours_ppmi_lsa_k_ae.to_pickle("enriched_neighbours_ppmi_lsa_k_ae")
neighbours_enriched_ae = vsm.neighbors(test_word, enriched_neighbours_ppmi_lsa_k_ae).head()
print("neighbours_enriched")
print(neighbours_enriched_ae)

recons tensor(11.9262, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0103, grad_fn=<SubBackward0>)
epoch_error 11.936442255973816
recons tensor(11.7087, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0100, grad_fn=<SubBackward0>)
epoch_error 23.655068516731262
recons tensor(11.1627, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0108, grad_fn=<SubBackward0>)
epoch_error 34.82855975627899
recons tensor(11.1363, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0120, grad_fn=<SubBackward0>)
epoch_error 45.976874351501465
recons tensor(10.1972, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0129, grad_fn=<SubBackward0>)
epoch_error 56.18700158596039


Finished epoch 1 of 100; error is 10.210126876831055

recons tensor(10.3068, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0145, grad_fn=<SubBackward0>)
epoch_error 10.32129716873169
recons tensor(9.3004, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0153, grad_fn=<SubBackward0>)
epoch_error 19.636953711509705
recons tensor(9.8056, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0164, grad_fn=<SubBackward0>)
epoch_error 29.459035277366638
recons tensor(9.6190, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0164, grad_fn=<SubBackward0>)
epoch_error 39.09436595439911
recons tensor(9.3844, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0176, grad_fn=<SubBackward0>)
epoch_error 48.49633526802063


Finished epoch 2 of 100; error is 9.401968955993652

recons tensor(9.1140, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0173, grad_fn=<SubBackward0>)
epoch_error 9.13133704662323
recons tensor(8.8895, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0184, grad_fn=<SubBackward0>)
epoch_error 18.03919005393982
recons tensor(8.5985, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0185, grad_fn=<SubBackward0>)
epoch_error 26.65614879131317
recons tensor(8.5323, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0187, grad_fn=<SubBackward0>)
epoch_error 35.20718538761139
recons tensor(8.5169, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0190, grad_fn=<SubBackward0>)
epoch_error 43.74307990074158


Finished epoch 3 of 100; error is 8.535894393920898

recons tensor(7.9438, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0192, grad_fn=<SubBackward0>)
epoch_error 7.962967395782471
recons tensor(8.1760, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0191, grad_fn=<SubBackward0>)
epoch_error 16.158053755760193
recons tensor(7.4233, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0189, grad_fn=<SubBackward0>)
epoch_error 23.60018765926361
recons tensor(8.4802, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0186, grad_fn=<SubBackward0>)
epoch_error 32.09900605678558
recons tensor(7.8371, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0189, grad_fn=<SubBackward0>)
epoch_error 39.9549640417099


Finished epoch 4 of 100; error is 7.855957984924316

recons tensor(7.6045, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0187, grad_fn=<SubBackward0>)
epoch_error 7.623134136199951
recons tensor(7.5323, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0181, grad_fn=<SubBackward0>)
epoch_error 15.173454523086548
recons tensor(7.4494, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0183, grad_fn=<SubBackward0>)
epoch_error 22.641180634498596
recons tensor(7.1642, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0176, grad_fn=<SubBackward0>)
epoch_error 29.822944402694702
recons tensor(6.7834, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0175, grad_fn=<SubBackward0>)
epoch_error 36.62385964393616


Finished epoch 5 of 100; error is 6.800915241241455

recons tensor(6.7639, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0172, grad_fn=<SubBackward0>)
epoch_error 6.7811057567596436
recons tensor(6.6815, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0168, grad_fn=<SubBackward0>)
epoch_error 13.479470491409302
recons tensor(6.7553, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0170, grad_fn=<SubBackward0>)
epoch_error 20.251789569854736
recons tensor(6.6167, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0168, grad_fn=<SubBackward0>)
epoch_error 26.885268449783325
recons tensor(6.8144, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0165, grad_fn=<SubBackward0>)
epoch_error 33.71613931655884


Finished epoch 6 of 100; error is 6.830870628356934

recons tensor(6.2232, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0160, grad_fn=<SubBackward0>)
epoch_error 6.239180088043213
recons tensor(6.6597, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0168, grad_fn=<SubBackward0>)
epoch_error 12.915623188018799
recons tensor(6.3146, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0165, grad_fn=<SubBackward0>)
epoch_error 19.246809124946594
recons tensor(6.0211, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0159, grad_fn=<SubBackward0>)
epoch_error 25.283835649490356
recons tensor(5.7005, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0159, grad_fn=<SubBackward0>)
epoch_error 31.000243544578552


Finished epoch 7 of 100; error is 5.716407775878906

recons tensor(6.0022, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0163, grad_fn=<SubBackward0>)
epoch_error 6.018530964851379
recons tensor(6.2551, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0161, grad_fn=<SubBackward0>)
epoch_error 12.289755463600159
recons tensor(5.4570, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0153, grad_fn=<SubBackward0>)
epoch_error 17.762007474899292
recons tensor(5.4085, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0154, grad_fn=<SubBackward0>)
epoch_error 23.185935378074646
recons tensor(5.4794, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0154, grad_fn=<SubBackward0>)
epoch_error 28.680777430534363


Finished epoch 8 of 100; error is 5.494842052459717

recons tensor(5.1918, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0151, grad_fn=<SubBackward0>)
epoch_error 5.206884980201721
recons tensor(5.2070, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0149, grad_fn=<SubBackward0>)
epoch_error 10.428728699684143
recons tensor(5.6758, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0155, grad_fn=<SubBackward0>)
epoch_error 16.120039343833923
recons tensor(5.3684, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0152, grad_fn=<SubBackward0>)
epoch_error 21.503646969795227
recons tensor(5.0743, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0149, grad_fn=<SubBackward0>)
epoch_error 26.59287393093109


Finished epoch 9 of 100; error is 5.089226722717285

recons tensor(5.0639, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0146, grad_fn=<SubBackward0>)
epoch_error 5.078501105308533
recons tensor(4.9213, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0145, grad_fn=<SubBackward0>)
epoch_error 10.014325141906738
recons tensor(5.0797, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0145, grad_fn=<SubBackward0>)
epoch_error 15.108578443527222
recons tensor(4.9177, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0145, grad_fn=<SubBackward0>)
epoch_error 20.04083776473999
recons tensor(4.7358, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0141, grad_fn=<SubBackward0>)
epoch_error 24.790724992752075


Finished epoch 10 of 100; error is 4.749887466430664

recons tensor(4.8510, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0139, grad_fn=<SubBackward0>)
epoch_error 4.864915132522583
recons tensor(4.7787, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0138, grad_fn=<SubBackward0>)
epoch_error 9.657413244247437
recons tensor(4.8080, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0140, grad_fn=<SubBackward0>)
epoch_error 14.479337692260742
recons tensor(4.4977, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0134, grad_fn=<SubBackward0>)
epoch_error 18.990426182746887
recons tensor(4.1652, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0136, grad_fn=<SubBackward0>)
epoch_error 23.169200658798218


Finished epoch 11 of 100; error is 4.178774356842041

recons tensor(4.5027, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0134, grad_fn=<SubBackward0>)
epoch_error 4.5161226987838745
recons tensor(4.4913, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0135, grad_fn=<SubBackward0>)
epoch_error 9.020899176597595
recons tensor(4.2103, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0128, grad_fn=<SubBackward0>)
epoch_error 13.243986248970032
recons tensor(4.4066, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0130, grad_fn=<SubBackward0>)
epoch_error 17.66357386112213
recons tensor(4.1121, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0133, grad_fn=<SubBackward0>)
epoch_error 21.78895127773285


Finished epoch 12 of 100; error is 4.125377655029297

recons tensor(4.3929, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0133, grad_fn=<SubBackward0>)
epoch_error 4.406165242195129
recons tensor(4.0373, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0122, grad_fn=<SubBackward0>)
epoch_error 8.45566737651825
recons tensor(4.0982, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0130, grad_fn=<SubBackward0>)
epoch_error 12.566866397857666
recons tensor(4.0754, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0125, grad_fn=<SubBackward0>)
epoch_error 16.65482723712921
recons tensor(3.8888, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0130, grad_fn=<SubBackward0>)
epoch_error 20.556661009788513


Finished epoch 13 of 100; error is 3.9018337726593018

recons tensor(3.9771, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0126, grad_fn=<SubBackward0>)
epoch_error 3.9896700382232666
recons tensor(3.9170, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0126, grad_fn=<SubBackward0>)
epoch_error 7.91924774646759
recons tensor(3.9205, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0121, grad_fn=<SubBackward0>)
epoch_error 11.851831793785095
recons tensor(3.6845, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0118, grad_fn=<SubBackward0>)
epoch_error 15.548096537590027
recons tensor(3.9206, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0118, grad_fn=<SubBackward0>)
epoch_error 19.48050057888031


Finished epoch 14 of 100; error is 3.932404041290283

recons tensor(3.8083, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0120, grad_fn=<SubBackward0>)
epoch_error 3.8203046321868896
recons tensor(3.7842, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0122, grad_fn=<SubBackward0>)
epoch_error 7.616694688796997
recons tensor(3.6703, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0118, grad_fn=<SubBackward0>)
epoch_error 11.298725962638855
recons tensor(3.4918, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0113, grad_fn=<SubBackward0>)
epoch_error 14.801891922950745
recons tensor(3.6746, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0116, grad_fn=<SubBackward0>)
epoch_error 18.488067746162415


Finished epoch 15 of 100; error is 3.68617582321167

recons tensor(3.4432, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0114, grad_fn=<SubBackward0>)
epoch_error 3.454654574394226
recons tensor(3.6075, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0116, grad_fn=<SubBackward0>)
epoch_error 7.07377028465271
recons tensor(3.6388, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0114, grad_fn=<SubBackward0>)
epoch_error 10.723888278007507
recons tensor(3.4008, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0112, grad_fn=<SubBackward0>)
epoch_error 14.135921478271484
recons tensor(3.4293, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0115, grad_fn=<SubBackward0>)
epoch_error 17.57671868801117


Finished epoch 16 of 100; error is 3.4407973289489746

recons tensor(3.2878, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0109, grad_fn=<SubBackward0>)
epoch_error 3.2987223863601685
recons tensor(3.3697, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0108, grad_fn=<SubBackward0>)
epoch_error 6.679304599761963
recons tensor(3.3831, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0106, grad_fn=<SubBackward0>)
epoch_error 10.073023438453674
recons tensor(3.2886, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0108, grad_fn=<SubBackward0>)
epoch_error 13.372417688369751
recons tensor(3.3728, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0112, grad_fn=<SubBackward0>)
epoch_error 16.756437063217163


Finished epoch 17 of 100; error is 3.384019374847412

recons tensor(3.3893, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0105, grad_fn=<SubBackward0>)
epoch_error 3.3997963666915894
recons tensor(3.1528, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0106, grad_fn=<SubBackward0>)
epoch_error 6.563149809837341
recons tensor(3.0083, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0101, grad_fn=<SubBackward0>)
epoch_error 9.581554412841797
recons tensor(3.3580, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0101, grad_fn=<SubBackward0>)
epoch_error 12.94962465763092
recons tensor(3.0071, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0102, grad_fn=<SubBackward0>)
epoch_error 15.966999769210815


Finished epoch 18 of 100; error is 3.0173749923706055

recons tensor(2.9758, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0102, grad_fn=<SubBackward0>)
epoch_error 2.9860156774520874
recons tensor(3.2070, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0104, grad_fn=<SubBackward0>)
epoch_error 6.20341157913208
recons tensor(2.9375, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0097, grad_fn=<SubBackward0>)
epoch_error 9.15064799785614
recons tensor(3.0344, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0101, grad_fn=<SubBackward0>)
epoch_error 12.1951664686203
recons tensor(3.0788, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0098, grad_fn=<SubBackward0>)
epoch_error 15.283780455589294


Finished epoch 19 of 100; error is 3.088613986968994

recons tensor(2.9270, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0099, grad_fn=<SubBackward0>)
epoch_error 2.936958909034729
recons tensor(2.8066, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0094, grad_fn=<SubBackward0>)
epoch_error 5.752964735031128
recons tensor(2.7775, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0097, grad_fn=<SubBackward0>)
epoch_error 8.540242671966553
recons tensor(3.1418, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0100, grad_fn=<SubBackward0>)
epoch_error 11.692044734954834
recons tensor(2.9172, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0096, grad_fn=<SubBackward0>)
epoch_error 14.618896245956421


Finished epoch 20 of 100; error is 2.926851511001587

recons tensor(2.8602, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0096, grad_fn=<SubBackward0>)
epoch_error 2.8697785139083862
recons tensor(3.0297, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0089, grad_fn=<SubBackward0>)
epoch_error 5.908384084701538
recons tensor(2.6773, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0094, grad_fn=<SubBackward0>)
epoch_error 8.595141172409058
recons tensor(2.7091, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0094, grad_fn=<SubBackward0>)
epoch_error 11.313728332519531
recons tensor(2.6775, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0091, grad_fn=<SubBackward0>)
epoch_error 14.000323295593262


Finished epoch 21 of 100; error is 2.6865949630737305

recons tensor(2.8025, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0094, grad_fn=<SubBackward0>)
epoch_error 2.811857223510742
recons tensor(2.7105, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0091, grad_fn=<SubBackward0>)
epoch_error 5.5314812660217285
recons tensor(2.5911, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0087, grad_fn=<SubBackward0>)
epoch_error 8.13136100769043
recons tensor(2.6453, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0089, grad_fn=<SubBackward0>)
epoch_error 10.785549640655518
recons tensor(2.6515, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0087, grad_fn=<SubBackward0>)
epoch_error 13.445790410041809


Finished epoch 22 of 100; error is 2.660240650177002

recons tensor(2.7292, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0089, grad_fn=<SubBackward0>)
epoch_error 2.738085150718689
recons tensor(2.5130, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0086, grad_fn=<SubBackward0>)
epoch_error 5.2596595287323
recons tensor(2.4460, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0082, grad_fn=<SubBackward0>)
epoch_error 7.713876128196716
recons tensor(2.6141, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0084, grad_fn=<SubBackward0>)
epoch_error 10.336371541023254
recons tensor(2.5865, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0091, grad_fn=<SubBackward0>)
epoch_error 12.931969046592712


Finished epoch 23 of 100; error is 2.595597505569458

recons tensor(2.4767, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0083, grad_fn=<SubBackward0>)
epoch_error 2.4849987030029297
recons tensor(2.6759, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0086, grad_fn=<SubBackward0>)
epoch_error 5.169494032859802
recons tensor(2.4628, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0086, grad_fn=<SubBackward0>)
epoch_error 7.640852689743042
recons tensor(2.3932, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0082, grad_fn=<SubBackward0>)
epoch_error 10.042283415794373
recons tensor(2.3856, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0079, grad_fn=<SubBackward0>)
epoch_error 12.435718297958374


Finished epoch 24 of 100; error is 2.393435001373291

recons tensor(2.3064, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0079, grad_fn=<SubBackward0>)
epoch_error 2.3143683671951294
recons tensor(2.5712, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0084, grad_fn=<SubBackward0>)
epoch_error 4.893986821174622
recons tensor(2.4008, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0085, grad_fn=<SubBackward0>)
epoch_error 7.303307890892029
recons tensor(2.3998, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0078, grad_fn=<SubBackward0>)
epoch_error 9.710983633995056
recons tensor(2.2578, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0081, grad_fn=<SubBackward0>)
epoch_error 11.976892232894897


Finished epoch 25 of 100; error is 2.265908718109131

recons tensor(2.3139, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0080, grad_fn=<SubBackward0>)
epoch_error 2.3219199180603027
recons tensor(2.2576, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0080, grad_fn=<SubBackward0>)
epoch_error 4.587507843971252
recons tensor(2.3381, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0076, grad_fn=<SubBackward0>)
epoch_error 6.933156371116638
recons tensor(2.4388, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0083, grad_fn=<SubBackward0>)
epoch_error 9.380259275436401
recons tensor(2.1794, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0075, grad_fn=<SubBackward0>)
epoch_error 11.567179441452026


Finished epoch 26 of 100; error is 2.186920166015625

recons tensor(2.3541, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0077, grad_fn=<SubBackward0>)
epoch_error 2.361884832382202
recons tensor(2.2770, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0081, grad_fn=<SubBackward0>)
epoch_error 4.646984815597534
recons tensor(2.1527, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0077, grad_fn=<SubBackward0>)
epoch_error 6.807377338409424
recons tensor(2.2004, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0076, grad_fn=<SubBackward0>)
epoch_error 9.015391230583191
recons tensor(2.1853, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0073, grad_fn=<SubBackward0>)
epoch_error 11.207966208457947


Finished epoch 27 of 100; error is 2.192574977874756

recons tensor(2.2088, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0076, grad_fn=<SubBackward0>)
epoch_error 2.2164175510406494
recons tensor(2.1256, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0074, grad_fn=<SubBackward0>)
epoch_error 4.3494696617126465
recons tensor(2.1706, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0075, grad_fn=<SubBackward0>)
epoch_error 6.527547121047974
recons tensor(2.2180, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0073, grad_fn=<SubBackward0>)
epoch_error 8.752920031547546
recons tensor(2.0940, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0075, grad_fn=<SubBackward0>)
epoch_error 10.854400396347046


Finished epoch 28 of 100; error is 2.101480484008789

recons tensor(2.0444, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0072, grad_fn=<SubBackward0>)
epoch_error 2.0516010522842407
recons tensor(2.2391, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0071, grad_fn=<SubBackward0>)
epoch_error 4.297778010368347
recons tensor(2.1020, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0073, grad_fn=<SubBackward0>)
epoch_error 6.4071245193481445
recons tensor(2.0667, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0073, grad_fn=<SubBackward0>)
epoch_error 8.481110095977783
recons tensor(2.0541, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0069, grad_fn=<SubBackward0>)
epoch_error 10.542150855064392


Finished epoch 29 of 100; error is 2.0610408782958984

recons tensor(1.9931, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0072, grad_fn=<SubBackward0>)
epoch_error 2.0003217458724976
recons tensor(1.9630, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0071, grad_fn=<SubBackward0>)
epoch_error 3.9704023599624634
recons tensor(2.1363, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0074, grad_fn=<SubBackward0>)
epoch_error 6.114135384559631
recons tensor(2.1462, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0075, grad_fn=<SubBackward0>)
epoch_error 8.267815709114075
recons tensor(1.9727, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 10.247208952903748


Finished epoch 30 of 100; error is 1.9793932437896729

recons tensor(1.9903, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0073, grad_fn=<SubBackward0>)
epoch_error 1.9975813627243042
recons tensor(1.9376, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 3.9417253732681274
recons tensor(2.0725, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0074, grad_fn=<SubBackward0>)
epoch_error 6.021690011024475
recons tensor(1.9745, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0069, grad_fn=<SubBackward0>)
epoch_error 8.003108382225037
recons tensor(1.9764, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 9.986224055290222


Finished epoch 31 of 100; error is 1.9831156730651855

recons tensor(1.9370, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0072, grad_fn=<SubBackward0>)
epoch_error 1.944193720817566
recons tensor(2.0689, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 4.0197993516922
recons tensor(1.9456, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0066, grad_fn=<SubBackward0>)
epoch_error 5.972007155418396
recons tensor(1.7975, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 7.775627017021179
recons tensor(1.9645, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 9.74682903289795


Finished epoch 32 of 100; error is 1.97120201587677

recons tensor(1.9095, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0068, grad_fn=<SubBackward0>)
epoch_error 1.9162561893463135
recons tensor(1.9432, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0069, grad_fn=<SubBackward0>)
epoch_error 3.866315722465515
recons tensor(1.8901, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 5.7629475593566895
recons tensor(1.8626, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 7.632081508636475
recons tensor(1.8749, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0061, grad_fn=<SubBackward0>)
epoch_error 9.513075828552246


Finished epoch 33 of 100; error is 1.8809943199157715

recons tensor(1.8536, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0060, grad_fn=<SubBackward0>)
epoch_error 1.8596534729003906
recons tensor(1.8325, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0066, grad_fn=<SubBackward0>)
epoch_error 3.698737859725952
recons tensor(1.8703, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 5.575557589530945
recons tensor(1.8805, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0066, grad_fn=<SubBackward0>)
epoch_error 7.4626384973526
recons tensor(1.8185, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0066, grad_fn=<SubBackward0>)
epoch_error 9.287753343582153


Finished epoch 34 of 100; error is 1.8251148462295532

recons tensor(1.7881, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 1.7946542501449585
recons tensor(1.8507, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0060, grad_fn=<SubBackward0>)
epoch_error 3.6513900756835938
recons tensor(1.8217, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 5.479549407958984
recons tensor(1.8055, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0063, grad_fn=<SubBackward0>)
epoch_error 7.291338801383972
recons tensor(1.7921, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0065, grad_fn=<SubBackward0>)
epoch_error 9.089959621429443


Finished epoch 35 of 100; error is 1.7986208200454712

recons tensor(1.6927, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 1.6988606452941895
recons tensor(1.7916, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 3.496598720550537
recons tensor(1.8770, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0063, grad_fn=<SubBackward0>)
epoch_error 5.379880309104919
recons tensor(1.8129, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 7.19946563243866
recons tensor(1.7001, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0059, grad_fn=<SubBackward0>)
epoch_error 8.905460476875305


Finished epoch 36 of 100; error is 1.7059948444366455

recons tensor(1.8077, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0061, grad_fn=<SubBackward0>)
epoch_error 1.8137760162353516
recons tensor(1.7812, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0067, grad_fn=<SubBackward0>)
epoch_error 3.6016021966934204
recons tensor(1.7359, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 5.343290328979492
recons tensor(1.6558, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 7.004786252975464
recons tensor(1.7210, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 8.731975555419922


Finished epoch 37 of 100; error is 1.727189302444458

recons tensor(1.7922, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 1.798466682434082
recons tensor(1.7055, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 3.5096739530563354
recons tensor(1.7408, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0062, grad_fn=<SubBackward0>)
epoch_error 5.256661534309387
recons tensor(1.7000, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0060, grad_fn=<SubBackward0>)
epoch_error 6.962702631950378
recons tensor(1.5864, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0063, grad_fn=<SubBackward0>)
epoch_error 8.555382013320923


Finished epoch 38 of 100; error is 1.5926793813705444

recons tensor(1.6908, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0055, grad_fn=<SubBackward0>)
epoch_error 1.6962761878967285
recons tensor(1.6624, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0061, grad_fn=<SubBackward0>)
epoch_error 3.3647701740264893
recons tensor(1.7146, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 5.0852210521698
recons tensor(1.6909, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 6.78183126449585
recons tensor(1.6343, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 8.421950936317444


Finished epoch 39 of 100; error is 1.6401196718215942

recons tensor(1.6110, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 1.616735816001892
recons tensor(1.5785, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0055, grad_fn=<SubBackward0>)
epoch_error 3.2006964683532715
recons tensor(1.7198, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0060, grad_fn=<SubBackward0>)
epoch_error 4.926491618156433
recons tensor(1.7046, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 6.636882901191711
recons tensor(1.6498, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 8.292575001716614


Finished epoch 40 of 100; error is 1.6556921005249023

recons tensor(1.6399, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0055, grad_fn=<SubBackward0>)
epoch_error 1.6453555822372437
recons tensor(1.5839, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 3.234978437423706
recons tensor(1.5934, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 4.8341522216796875
recons tensor(1.6895, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 6.529394865036011
recons tensor(1.6213, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0056, grad_fn=<SubBackward0>)
epoch_error 8.156265377998352


Finished epoch 41 of 100; error is 1.6268705129623413

recons tensor(1.5245, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 1.5303267240524292
recons tensor(1.6144, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 3.1501975059509277
recons tensor(1.5649, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0053, grad_fn=<SubBackward0>)
epoch_error 4.720389485359192
recons tensor(1.6174, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 6.343476414680481
recons tensor(1.6998, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 8.049037337303162


Finished epoch 42 of 100; error is 1.7055609226226807

recons tensor(1.6351, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 1.6408330202102661
recons tensor(1.5587, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0055, grad_fn=<SubBackward0>)
epoch_error 3.2050634622573853
recons tensor(1.5725, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0052, grad_fn=<SubBackward0>)
epoch_error 4.782795429229736
recons tensor(1.6098, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 6.398322343826294
recons tensor(1.5355, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 7.939153432846069


Finished epoch 43 of 100; error is 1.5408310890197754

recons tensor(1.5485, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 1.5539103746414185
recons tensor(1.5881, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 3.1476480960845947
recons tensor(1.6156, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 4.768692374229431
recons tensor(1.4938, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 6.2676485776901245
recons tensor(1.5464, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 7.8190754652023315


Finished epoch 44 of 100; error is 1.551426887512207

recons tensor(1.5857, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 1.5915169715881348
recons tensor(1.5726, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0056, grad_fn=<SubBackward0>)
epoch_error 3.169658064842224
recons tensor(1.5318, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 4.706381440162659
recons tensor(1.5375, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0053, grad_fn=<SubBackward0>)
epoch_error 6.249216794967651
recons tensor(1.4480, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 7.7026742696762085


Finished epoch 45 of 100; error is 1.4534574747085571

recons tensor(1.4815, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 1.4861867427825928
recons tensor(1.5536, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0055, grad_fn=<SubBackward0>)
epoch_error 3.0452494621276855
recons tensor(1.6037, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0058, grad_fn=<SubBackward0>)
epoch_error 4.6547369956970215
recons tensor(1.4583, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 6.117981791496277
recons tensor(1.5126, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0057, grad_fn=<SubBackward0>)
epoch_error 7.636317849159241


Finished epoch 46 of 100; error is 1.5183360576629639

recons tensor(1.5334, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 1.5384418964385986
recons tensor(1.4211, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 2.9649899005889893
recons tensor(1.5071, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0052, grad_fn=<SubBackward0>)
epoch_error 4.477278470993042
recons tensor(1.5540, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 6.0363253355026245
recons tensor(1.4999, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0059, grad_fn=<SubBackward0>)
epoch_error 7.542112231254578


Finished epoch 47 of 100; error is 1.5057868957519531

recons tensor(1.5019, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0048, grad_fn=<SubBackward0>)
epoch_error 1.5066876411437988
recons tensor(1.4889, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 3.0009610652923584
recons tensor(1.4584, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 4.4644893407821655
recons tensor(1.4666, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 5.93569278717041
recons tensor(1.5355, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0052, grad_fn=<SubBackward0>)
epoch_error 7.476363897323608


Finished epoch 48 of 100; error is 1.5406711101531982

recons tensor(1.4441, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 1.4487782716751099
recons tensor(1.4281, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 2.8814356327056885
recons tensor(1.4976, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0056, grad_fn=<SubBackward0>)
epoch_error 4.384571552276611
recons tensor(1.5014, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 5.890730381011963
recons tensor(1.4632, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 7.358897686004639


Finished epoch 49 of 100; error is 1.4681673049926758

recons tensor(1.4492, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 1.454211950302124
recons tensor(1.5668, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 3.0260459184646606
recons tensor(1.4421, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 4.473284721374512
recons tensor(1.4213, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 5.899130344390869
recons tensor(1.3835, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 7.287721395492554


Finished epoch 50 of 100; error is 1.3885910511016846

recons tensor(1.4677, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 1.4725381135940552
recons tensor(1.4947, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0053, grad_fn=<SubBackward0>)
epoch_error 2.9725120067596436
recons tensor(1.4235, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 4.400584697723389
recons tensor(1.4266, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 5.832029461860657
recons tensor(1.3562, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 7.193139910697937


Finished epoch 51 of 100; error is 1.3611104488372803

recons tensor(1.3551, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 1.359947681427002
recons tensor(1.4393, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 2.804214596748352
recons tensor(1.4388, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0048, grad_fn=<SubBackward0>)
epoch_error 4.2478266954422
recons tensor(1.4177, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 5.669839143753052
recons tensor(1.4809, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0054, grad_fn=<SubBackward0>)
epoch_error 7.1561325788497925


Finished epoch 52 of 100; error is 1.4862934350967407

recons tensor(1.4459, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 1.4509804248809814
recons tensor(1.4401, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 2.8956152200698853
recons tensor(1.4212, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0050, grad_fn=<SubBackward0>)
epoch_error 4.32185161113739
recons tensor(1.3633, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 5.689688444137573
recons tensor(1.3766, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 7.0710049867630005


Finished epoch 53 of 100; error is 1.3813165426254272

recons tensor(1.4129, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0051, grad_fn=<SubBackward0>)
epoch_error 1.4179503917694092
recons tensor(1.4547, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 2.877370834350586
recons tensor(1.3587, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 4.241025447845459
recons tensor(1.3853, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 5.631072640419006
recons tensor(1.3615, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 6.997207045555115


Finished epoch 54 of 100; error is 1.3661344051361084

recons tensor(1.3100, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 1.3145809173583984
recons tensor(1.4407, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 2.7601295709609985
recons tensor(1.3234, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 4.0880115032196045
recons tensor(1.4160, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 5.508949875831604
recons tensor(1.4180, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 6.93160355091095


Finished epoch 55 of 100; error is 1.4226536750793457

recons tensor(1.3744, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 1.3787699937820435
recons tensor(1.3772, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 2.7608954906463623
recons tensor(1.3983, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 4.163925766944885
recons tensor(1.3576, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 5.525994420051575
recons tensor(1.3399, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0049, grad_fn=<SubBackward0>)
epoch_error 6.870738863945007


Finished epoch 56 of 100; error is 1.3447444438934326

recons tensor(1.3413, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 1.3458172082901
recons tensor(1.3959, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 2.7463828325271606
recons tensor(1.3165, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 4.067120432853699
recons tensor(1.3681, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 5.4393638372421265
recons tensor(1.3858, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 6.829774498939514


Finished epoch 57 of 100; error is 1.3904106616973877

recons tensor(1.3343, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 1.3388752937316895
recons tensor(1.3437, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 2.6870498657226562
recons tensor(1.3765, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0048, grad_fn=<SubBackward0>)
epoch_error 4.0683735609054565
recons tensor(1.3736, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 5.44602644443512
recons tensor(1.3549, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 6.805570960044861


Finished epoch 58 of 100; error is 1.3595445156097412

recons tensor(1.3826, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 1.3871006965637207
recons tensor(1.3934, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0048, grad_fn=<SubBackward0>)
epoch_error 2.7852678298950195
recons tensor(1.2416, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 4.031317114830017
recons tensor(1.3343, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.369614601135254
recons tensor(1.3773, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 6.751359820365906


Finished epoch 59 of 100; error is 1.3817452192306519

recons tensor(1.3393, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 1.3439596891403198
recons tensor(1.3625, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 2.710740327835083
recons tensor(1.3124, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 4.027634501457214
recons tensor(1.3279, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 5.35998809337616
recons tensor(1.3471, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 6.711474061012268


Finished epoch 60 of 100; error is 1.3514859676361084

recons tensor(1.3190, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0048, grad_fn=<SubBackward0>)
epoch_error 1.3237813711166382
recons tensor(1.3313, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 2.6595051288604736
recons tensor(1.3250, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 3.9886765480041504
recons tensor(1.3342, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 5.327410578727722
recons tensor(1.3108, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 6.642445921897888


Finished epoch 61 of 100; error is 1.315035343170166

recons tensor(1.3652, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 1.3694581985473633
recons tensor(1.3088, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0046, grad_fn=<SubBackward0>)
epoch_error 2.6828978061676025
recons tensor(1.2728, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 3.9596080780029297
recons tensor(1.3411, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 5.30511200428009
recons tensor(1.2897, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 6.599037528038025


Finished epoch 62 of 100; error is 1.2939255237579346

recons tensor(1.2621, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.2663027048110962
recons tensor(1.2700, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 2.540164589881897
recons tensor(1.2911, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 3.835359811782837
recons tensor(1.4014, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 5.241181492805481
recons tensor(1.3111, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0047, grad_fn=<SubBackward0>)
epoch_error 6.5569645166397095


Finished epoch 63 of 100; error is 1.3157830238342285

recons tensor(1.2941, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 1.2981661558151245
recons tensor(1.3102, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 2.6127554178237915
recons tensor(1.3201, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 3.93729567527771
recons tensor(1.2548, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.196174740791321
recons tensor(1.3196, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 6.520286321640015


Finished epoch 64 of 100; error is 1.3241115808486938

recons tensor(1.3268, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 1.3311423063278198
recons tensor(1.3599, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 2.6950442790985107
recons tensor(1.2669, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 3.9663405418395996
recons tensor(1.2645, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.234789967536926
recons tensor(1.2304, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 6.469070672988892


Finished epoch 65 of 100; error is 1.2342807054519653

recons tensor(1.2998, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 1.3043212890625
recons tensor(1.3371, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.645253539085388
recons tensor(1.2595, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 3.9090837240219116
recons tensor(1.2612, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 5.174520969390869
recons tensor(1.2602, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 6.438985586166382


Finished epoch 66 of 100; error is 1.2644646167755127

recons tensor(1.2884, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 1.2925673723220825
recons tensor(1.2650, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 2.5612164735794067
recons tensor(1.2391, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 3.804492473602295
recons tensor(1.3481, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 5.156708836555481
recons tensor(1.2404, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 6.401188373565674


Finished epoch 67 of 100; error is 1.2444795370101929

recons tensor(1.2512, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.2553900480270386
recons tensor(1.3323, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 2.591899275779724
recons tensor(1.2640, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 3.8603808879852295
recons tensor(1.1787, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 5.042708039283752
recons tensor(1.3019, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 6.348739385604858


Finished epoch 68 of 100; error is 1.306031346321106

recons tensor(1.2358, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.2399444580078125
recons tensor(1.2109, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 2.454713821411133
recons tensor(1.3427, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 3.801619291305542
recons tensor(1.2934, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 5.099117755889893
recons tensor(1.2020, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 6.304817080497742


Finished epoch 69 of 100; error is 1.2056993246078491

recons tensor(1.2884, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 1.292880892753601
recons tensor(1.2408, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 2.5378791093826294
recons tensor(1.2164, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 3.758521556854248
recons tensor(1.3014, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.063915014266968
recons tensor(1.2192, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 6.286738634109497


Finished epoch 70 of 100; error is 1.2228236198425293

recons tensor(1.2368, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 1.2406609058380127
recons tensor(1.2941, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 2.5386762619018555
recons tensor(1.2635, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 3.806307554244995
recons tensor(1.2197, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.03000271320343
recons tensor(1.2388, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 6.272638201713562


Finished epoch 71 of 100; error is 1.2426354885101318

recons tensor(1.2706, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 1.2749314308166504
recons tensor(1.2176, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.4963042736053467
recons tensor(1.2579, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 3.7580429315567017
recons tensor(1.2507, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 5.01288366317749
recons tensor(1.2085, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 6.225198149681091


Finished epoch 72 of 100; error is 1.212314486503601

recons tensor(1.2570, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.2611587047576904
recons tensor(1.2205, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 2.4855931997299194
recons tensor(1.2925, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 3.7820099592208862
recons tensor(1.2022, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 4.988181114196777
recons tensor(1.2125, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 6.204198479652405


Finished epoch 73 of 100; error is 1.2160173654556274

recons tensor(1.2796, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 1.2836110591888428
recons tensor(1.1954, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 2.4828710556030273
recons tensor(1.2046, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 3.691235899925232
recons tensor(1.2039, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 4.899105191230774
recons tensor(1.2714, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 6.174710512161255


Finished epoch 74 of 100; error is 1.275605320930481

recons tensor(1.3046, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 1.3085187673568726
recons tensor(1.2031, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 2.515652298927307
recons tensor(1.2569, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 3.7764012813568115
recons tensor(1.1538, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 4.933871269226074
recons tensor(1.1998, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 6.137318849563599


Finished epoch 75 of 100; error is 1.2034475803375244

recons tensor(1.2432, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0043, grad_fn=<SubBackward0>)
epoch_error 1.2475560903549194
recons tensor(1.2086, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 2.4594615697860718
recons tensor(1.2406, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 3.704398274421692
recons tensor(1.2773, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 4.985068082809448
recons tensor(1.2311, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 6.21990168094635


Finished epoch 76 of 100; error is 1.2348335981369019

recons tensor(1.2159, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.2200253009796143
recons tensor(1.2971, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 2.5205700397491455
recons tensor(1.2268, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 3.75152587890625
recons tensor(1.2891, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 5.044314861297607
recons tensor(1.2495, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 6.297725439071655


Finished epoch 77 of 100; error is 1.2534105777740479

recons tensor(1.2983, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0044, grad_fn=<SubBackward0>)
epoch_error 1.3026591539382935
recons tensor(1.2005, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 2.5066466331481934
recons tensor(1.1991, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 3.7092978954315186
recons tensor(1.2061, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 4.919402480125427
recons tensor(1.2430, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 6.165684223175049


Finished epoch 78 of 100; error is 1.2462817430496216

recons tensor(1.2308, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0041, grad_fn=<SubBackward0>)
epoch_error 1.234929084777832
recons tensor(1.2114, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 2.450026035308838
recons tensor(1.1700, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 3.6234915256500244
recons tensor(1.2905, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0045, grad_fn=<SubBackward0>)
epoch_error 4.918450832366943
recons tensor(1.1821, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 6.104156255722046


Finished epoch 79 of 100; error is 1.1857054233551025

recons tensor(1.1768, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 1.180346965789795
recons tensor(1.1859, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.370029091835022
recons tensor(1.2212, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 3.595150351524353
recons tensor(1.2511, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 4.849909663200378
recons tensor(1.1732, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 6.02708113193512


Finished epoch 80 of 100; error is 1.1771714687347412

recons tensor(1.2007, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 1.2041071653366089
recons tensor(1.2498, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 2.4578804969787598
recons tensor(1.1345, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 3.596091389656067
recons tensor(1.2065, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 4.806244134902954
recons tensor(1.1860, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 5.996174573898315


Finished epoch 81 of 100; error is 1.1899304389953613

recons tensor(1.2153, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 1.2191592454910278
recons tensor(1.2106, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.433485984802246
recons tensor(1.1779, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 3.6150903701782227
recons tensor(1.1518, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 4.770485758781433
recons tensor(1.1999, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 5.973803758621216


Finished epoch 82 of 100; error is 1.2033179998397827

recons tensor(1.1556, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 1.1594643592834473
recons tensor(1.1665, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 2.3295936584472656
recons tensor(1.1917, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 3.5250678062438965
recons tensor(1.2142, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 4.742988467216492
recons tensor(1.2044, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 5.951048254966736


Finished epoch 83 of 100; error is 1.2080597877502441

recons tensor(1.1512, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 1.1550523042678833
recons tensor(1.1670, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 2.3255796432495117
recons tensor(1.2028, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 3.5318620204925537
recons tensor(1.1688, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 4.704393744468689
recons tensor(1.2532, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 5.961037993431091


Finished epoch 84 of 100; error is 1.2566442489624023

recons tensor(1.1881, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0042, grad_fn=<SubBackward0>)
epoch_error 1.1922862529754639
recons tensor(1.2688, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 2.4646923542022705
recons tensor(1.1747, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 3.6426138877868652
recons tensor(1.1728, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 4.819235920906067
recons tensor(1.1543, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0029, grad_fn=<SubBackward0>)
epoch_error 5.97642195224762


Finished epoch 85 of 100; error is 1.1571860313415527

recons tensor(1.1953, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 1.1989507675170898
recons tensor(1.2188, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.4216123819351196
recons tensor(1.1983, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 3.6234368085861206
recons tensor(1.1631, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 4.79051661491394
recons tensor(1.1564, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 5.950397968292236


Finished epoch 86 of 100; error is 1.159881353378296

recons tensor(1.2266, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 1.2301454544067383
recons tensor(1.1907, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 2.4245569705963135
recons tensor(1.1700, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 3.5979418754577637
recons tensor(1.1762, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 4.777545928955078
recons tensor(1.1690, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 5.950383901596069


Finished epoch 87 of 100; error is 1.1728379726409912

recons tensor(1.1722, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 1.175477147102356
recons tensor(1.2296, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 2.4086220264434814
recons tensor(1.1532, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 3.5658475160598755
recons tensor(1.2182, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0031, grad_fn=<SubBackward0>)
epoch_error 4.78717827796936
recons tensor(1.1708, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 5.961740732192993


Finished epoch 88 of 100; error is 1.1745624542236328

recons tensor(1.1996, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 1.2028169631958008
recons tensor(1.1276, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 2.3337059020996094
recons tensor(1.2397, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0039, grad_fn=<SubBackward0>)
epoch_error 3.5773309469223022
recons tensor(1.1417, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 4.722424387931824
recons tensor(1.1815, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 5.907139420509338


Finished epoch 89 of 100; error is 1.1847150325775146

recons tensor(1.1755, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 1.179269552230835
recons tensor(1.2272, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 2.409837245941162
recons tensor(1.1623, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.575329065322876
recons tensor(1.1827, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0040, grad_fn=<SubBackward0>)
epoch_error 4.762019515037537
recons tensor(1.1517, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 5.917041778564453


Finished epoch 90 of 100; error is 1.1550222635269165

recons tensor(1.1451, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 1.1485275030136108
recons tensor(1.2305, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 2.3828152418136597
recons tensor(1.1316, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.517604112625122
recons tensor(1.1502, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 4.671309113502502
recons tensor(1.1694, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 5.844482183456421


Finished epoch 91 of 100; error is 1.1731730699539185

recons tensor(1.2007, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 1.2039989233016968
recons tensor(1.1450, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 2.352626323699951
recons tensor(1.1354, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 3.4913586378097534
recons tensor(1.1475, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 4.642295241355896
recons tensor(1.1781, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 5.8236987590789795


Finished epoch 92 of 100; error is 1.1814035177230835

recons tensor(1.1775, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 1.1808576583862305
recons tensor(1.1104, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 2.2944531440734863
recons tensor(1.1826, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0038, grad_fn=<SubBackward0>)
epoch_error 3.48089337348938
recons tensor(1.1539, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 4.637947201728821
recons tensor(1.1698, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 5.811140060424805


Finished epoch 93 of 100; error is 1.1731928586959839

recons tensor(1.1505, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 1.154110074043274
recons tensor(1.1327, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0031, grad_fn=<SubBackward0>)
epoch_error 2.289980411529541
recons tensor(1.1636, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 3.4571499824523926
recons tensor(1.1400, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 4.600584506988525
recons tensor(1.1892, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 5.792898058891296


Finished epoch 94 of 100; error is 1.192313551902771

recons tensor(1.1453, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 1.1488741636276245
recons tensor(1.1947, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0037, grad_fn=<SubBackward0>)
epoch_error 2.347267985343933
recons tensor(1.1346, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 3.485417604446411
recons tensor(1.1385, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 4.627283692359924
recons tensor(1.1130, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 5.7435808181762695


Finished epoch 95 of 100; error is 1.1162971258163452

recons tensor(1.1141, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 1.1176940202713013
recons tensor(1.1792, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 2.3003008365631104
recons tensor(1.1385, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.442004084587097
recons tensor(1.1232, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 4.568485260009766
recons tensor(1.1659, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 5.737846374511719


Finished epoch 96 of 100; error is 1.1693611145019531

recons tensor(1.1322, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 1.135506272315979
recons tensor(1.1330, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 2.2718470096588135
recons tensor(1.1377, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.412772297859192
recons tensor(1.1708, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 4.587159276008606
recons tensor(1.1140, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 5.7044278383255005


Finished epoch 97 of 100; error is 1.1172685623168945

recons tensor(1.1596, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 1.1629023551940918
recons tensor(1.0998, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0031, grad_fn=<SubBackward0>)
epoch_error 2.2657965421676636
recons tensor(1.1406, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.4095548391342163
recons tensor(1.1279, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 4.540984272956848
recons tensor(1.1406, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 5.684915900230408


Finished epoch 98 of 100; error is 1.1439316272735596

recons tensor(1.1588, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 1.1620447635650635
recons tensor(1.1162, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 2.281572103500366
recons tensor(1.1117, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0034, grad_fn=<SubBackward0>)
epoch_error 3.3965859413146973
recons tensor(1.1317, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 4.531588673591614
recons tensor(1.1346, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0036, grad_fn=<SubBackward0>)
epoch_error 5.66979455947876


Finished epoch 99 of 100; error is 1.138205885887146

recons tensor(1.1004, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0030, grad_fn=<SubBackward0>)
epoch_error 1.1034210920333862
recons tensor(1.1148, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 2.2213855981826782
recons tensor(1.1123, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0032, grad_fn=<SubBackward0>)
epoch_error 3.336967349052429
recons tensor(1.1599, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0033, grad_fn=<SubBackward0>)
epoch_error 4.500114560127258
recons tensor(1.1467, grad_fn=<MseLossBackward>)
mmd_err tensor(0.0035, grad_fn=<SubBackward0>)
epoch_error 5.650385499000549


Finished epoch 100 of 100; error is 1.150270938873291

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
!,0.860827,-0.529784,-0.875025,0.191395,-0.584016,0.104198,-0.872248,-0.426214,-0.720576,0.332389,...,-0.080874,0.140411,0.024004,0.685998,-0.704809,-0.743674,0.613740,0.693404,0.415883,0.500870
);,0.354351,0.214322,-0.779751,-0.492412,0.709453,-0.960830,-0.248109,0.479180,-0.841875,0.116841,...,-0.562501,-0.382599,-0.193444,-0.634701,-0.888504,-0.584067,0.863589,-0.689983,0.366530,0.505097
.,0.106176,0.153307,0.026138,0.125620,0.035706,-0.090320,0.028480,-0.102791,-0.121290,-0.229664,...,0.052002,0.005492,0.190977,-0.107113,-0.073235,-0.091404,0.122880,-0.007660,-0.004925,0.011887
...,0.389390,-0.070074,0.080465,-0.010523,-0.325886,0.143089,-0.611279,-0.406052,-0.420833,0.301197,...,-0.404950,-0.000491,0.153722,0.407269,-0.399067,-0.243941,0.239297,0.160682,-0.502048,0.231176
;p,-0.035943,-0.022684,0.251852,0.476309,-0.039681,0.087229,-0.172907,0.122161,-0.202519,-0.036543,...,-0.124706,-0.107030,-0.033157,-0.070131,-0.096281,0.077372,0.030573,-0.097830,-0.021083,0.320190
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
zebra,0.750245,0.383864,-0.572024,-0.762931,-0.653763,0.344779,0.005910,0.614450,-0.878966,-0.613693,...,-0.859222,-0.144411,-0.770528,0.082004,-0.713381,-0.848060,0.920986,0.323909,0.847942,0.832436
zinc,0.556657,-0.399992,0.090572,-0.063096,-0.102184,0.656546,0.885219,-0.460375,-0.753899,-0.180994,...,-0.447213,0.449200,-0.557913,-0.442591,-0.358168,0.141244,-0.486856,0.698193,-0.070678,0.780409
zombie,0.619725,-0.396205,-0.841874,-0.424807,-0.669406,-0.610342,-0.799636,-0.795378,-0.450646,-0.198200,...,-0.818874,0.737725,-0.410727,0.919908,-0.181519,-0.901532,0.822892,-0.691797,0.228586,0.184088
zone,-0.286202,-0.196985,-0.523980,-0.445213,0.300156,0.599161,0.291036,0.499800,0.064825,0.115376,...,-0.423696,0.276401,0.586188,0.183353,-0.216262,-0.683135,0.422102,-0.291242,-0.683411,0.458345


In [53]:
from collections import defaultdict
from nltk.corpus import wordnet as wn
import numpy as np
import os
import pandas as pd
import retrofitting
from retrofitting import Retrofitter
import utils
import vsm  # needed for vsm.neighbors below


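def get_wordnet_edges():
    """Map each WordNet lemma name to the set of lemma names that
    share a synset with it (the lemma itself is included)."""
    edges = defaultdict(set)
    for ss in wn.all_synsets():
        lem_names = {lem.name() for lem in ss.lemmas()}
        for lem in lem_names:
            edges[lem] |= lem_names
    return edges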

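def convert_edges_to_indices(edges, Q):
    """Translate word-level edges into row-index edges for the matrix Q,
    dropping any word that is not in Q's index."""
    lookup = dict(zip(Q.index, range(Q.shape[0])))
    index_edges = defaultdict(set)
    for start, finish_nodes in edges.items():
        s = lookup.get(start)
        # Compare against None explicitly: row index 0 is falsy and
        # would otherwise be skipped.
        if s is not None:
            f = {lookup[n] for n in finish_nodes if n in lookup}
            if f:
                index_edges[s] = f
    return index_edges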

# Build the WordNet graph, restrict it to the VSM's vocabulary,
# and retrofit the enriched matrix to that graph.
wn_edges = get_wordnet_edges()
wn_index_edges = convert_edges_to_indices(wn_edges, enriched_neighbours_ppmi_lsa_k_ae)
wn_retro = Retrofitter(verbose=True)
X_retro = wn_retro.fit(enriched_neighbours_ppmi_lsa_k_ae, wn_index_edges)
neighbours_retro = vsm.neighbors(test_word, X_retro).head()
print("neighbours_retro")
print(neighbours_retro)

Converged at iteration 8; change was 0.0044 

neighbours_retro
finance        0.000000
investment     0.114872
banking        0.126206
investments    0.139357
interest       0.173808
dtype: float64
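
The internals of `Retrofitter` are not shown in this notebook. As a rough illustration of the general retrofitting idea (in the style of Faruqui et al. 2015), here is a minimal NumPy sketch. It is not the course implementation: the function name `retrofit_sketch`, the weights (`alpha = 1`, `beta = 1 / degree`), and the stopping rule are all assumptions made for this example.

```python
import numpy as np


def retrofit_sketch(Q_hat, edges, alpha=1.0, max_iter=100, tol=1e-6):
    """Hypothetical retrofitting loop, for illustration only.

    Q_hat : (n, d) array of original vectors.
    edges : dict mapping a row index to a set of neighbour row indices.
    Each connected row is repeatedly replaced by a weighted average of
    its original vector and the current vectors of its neighbours.
    """
    Q_hat = np.asarray(Q_hat, dtype=float)
    Q = Q_hat.copy()
    for _ in range(max_iter):
        Q_prev = Q.copy()
        for i, nbrs in edges.items():
            nbrs = [j for j in nbrs if j != i]  # ignore self-loops
            if not nbrs:
                continue
            beta = 1.0 / len(nbrs)  # assumed weighting: 1 / degree
            neighbour_sum = beta * Q[nbrs].sum(axis=0)
            Q[i] = (alpha * Q_hat[i] + neighbour_sum) / (alpha + beta * len(nbrs))
        # Assumed stopping rule: largest absolute change in any cell.
        if np.abs(Q - Q_prev).max() < tol:
            break
    return Q


# Toy graph: rows 0 and 1 are linked, row 2 has no edges.
toy_Q = np.array([[0.0, 0.0], [2.0, 2.0], [5.0, 5.0]])
toy_edges = {0: {1}, 1: {0}}
toy_retro = retrofit_sketch(toy_Q, toy_edges)
print(toy_retro)
```

In the toy example the two linked rows are pulled toward each other while staying anchored to their original positions, and the unlinked row is left untouched.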


neighbours_enriched
finance        0.000000
investment     0.114872
banking        0.126206
investments    0.139357
monetary       0.182607
dtype: float64
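
The two printouts above agree on their first four entries and differ only in the fifth. A small sketch for making that comparison programmatically follows. The values are copied from the printed output, the helper name `compare_neighbours` is hypothetical, and the `neighbours_enriched` listing is assumed to come from the matrix before retrofitting.

```python
import pandas as pd

# Values copied from the printed output above.
neighbours_retro = pd.Series({
    "finance": 0.000000, "investment": 0.114872, "banking": 0.126206,
    "investments": 0.139357, "interest": 0.173808})
neighbours_enriched = pd.Series({
    "finance": 0.000000, "investment": 0.114872, "banking": 0.126206,
    "investments": 0.139357, "monetary": 0.182607})


def compare_neighbours(a, b):
    """Return words unique to each list and the largest distance
    difference among the words the two lists share."""
    shared = a.index.intersection(b.index)
    only_a = list(a.index.difference(b.index))
    only_b = list(b.index.difference(a.index))
    max_shift = (a[shared] - b[shared]).abs().max()
    return only_a, only_b, max_shift


only_retro, only_enriched, max_shift = compare_neighbours(
    neighbours_retro, neighbours_enriched)
print("only in retro:", only_retro)
print("only in enriched:", only_enriched)
print("largest shift among shared words:", max_shift)
```

Since `.head()` keeps only five rows, a word missing from one list may still be a near neighbour just below the cutoff; comparing longer lists would give a fuller picture.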


In [None]:
# On an otherwise blank line in this cell, please enter
# your "Macro-average" value as reported by the code above. 
# Please enter only a number between 0 and 1 inclusive.
# Please do not remove this comment.
if 'IS_GRADESCOPE_ENV' not in os.environ:
    pass
    # Please enter your score in the scope of the above conditional.
    ##### YOUR CODE HERE


