In [1]:
import argparse
import math
import os
import sys
from pathlib import Path
from model_t5 import T5VAE
import pretty_errors
import pytorch_lightning as pl
import torch
import torch.nn as nn
from base_models import datasets, models
from generate import generate
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from transformers import T5TokenizerFast
import numpy as np

In [6]:
checkpoint_path = "epoch=9-step=999999.ckpt"

In [7]:
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [8]:
tokenizer.encode("i can sing")

[3, 23, 54, 10159, 1]

In [9]:
model = T5VAE.load_from_checkpoint(
            checkpoint_path,
            strict=False,
            map_location="cpu",
            tokenizer=tokenizer,
            iterations_per_training_epoch=None,
            latent_dim=32,
            pooling_strategy=max,
            fixed_reg_weight=None,
            denoise_percentage=0.4,
            base_model="t5-small",
        )
model.eval()

T5VAE(
  (t5): ModifiedT5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=512, bias=False)
                (k): Linear(in_features=512, out_features=512, bias=False)
                (v): Linear(in_features=512, out_features=512, bias=False)
                (o): Linear(in_features=512, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 8)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseReluDense(
                (wi): Linear(in_features=512, out_features=2048, bias=False)
                (wo): Linear(in_features

In [41]:
model = T5VAE(tokenizer=tokenizer,iterations_per_training_epoch=None, latent_dim=32, pooling_strategy=max)

In [45]:
model.state_dict=torch.load("agg_finish.pt")

In [13]:
for param in model.t5.encoder.parameters():
    param.requires_grad = False
for param in model.t5.mu.parameters():
    param.requires_grad = False
for param in model.t5.logvar.parameters():
    param.requires_grad = False

In [14]:
for param in model.t5.parameters():
    print(param)

Parameter containing:
tensor([[ -1.8553,   0.2403,  -6.6939,  ...,  -0.2396,   2.5208,  -2.5168],
        [ 11.7849,   7.8757, -11.0178,  ...,   7.3823,  -6.7761,   0.6960],
        [ -7.9562,   7.4346,  25.7959,  ..., -24.7404,   1.0139,  -2.1245],
        ...,
        [-23.7003, -26.7507, -16.1771,  ..., -16.6771,  -4.9239,  25.7007],
        [-23.9440, -27.5745, -17.1404,  ..., -16.6771,  -4.5145,  26.0507],
        [-25.1047, -26.6340, -16.8021,  ..., -17.3654,  -6.5710,  25.9340]])
Parameter containing:
tensor([[-0.0185, -0.0705,  0.0058,  ...,  0.0244, -0.0350, -0.1119],
        [-0.0336, -0.0199,  0.0452,  ..., -0.0392, -0.0458,  0.0628],
        [ 0.0489, -0.0558,  0.0079,  ..., -0.0333, -0.0164, -0.0397],
        ...,
        [-0.0848,  0.0111, -0.1057,  ..., -0.0060, -0.0937,  0.0374],
        [ 0.0027,  0.0160, -0.0389,  ..., -0.0899,  0.0266,  0.0024],
        [-0.1017, -0.0027, -0.0020,  ...,  0.0489, -0.0398, -0.0490]])
Parameter containing:
tensor([[ 0.0692,  0.0934, -0.

        0.0775, 0.1105, 0.1145, 0.0738, 0.0927, 0.0425, 0.0535, 0.1053])
Parameter containing:
tensor([[-0.6443, -1.2655,  0.6087,  ..., -0.1978,  0.4557, -0.1418],
        [-0.0553, -0.3005, -0.1345,  ...,  0.3921,  0.1705,  0.4883],
        [ 0.6043,  0.2584,  0.3910,  ...,  0.4346, -0.1439,  0.0794],
        ...,
        [-0.5624,  0.0734, -0.6544,  ...,  0.8381, -0.1813,  0.4375],
        [-0.5636, -0.0116, -0.1201,  ..., -0.3393,  0.2246,  0.4744],
        [-0.2645, -0.3469,  0.1868,  ..., -0.0071,  0.6514, -0.0235]])
Parameter containing:
tensor([[-0.0528, -0.1267, -0.3040,  ..., -0.0286,  0.0270, -0.2928],
        [-0.3146,  0.2591,  0.0261,  ..., -0.1690, -0.1001, -0.2655],
        [ 0.2406,  0.0124, -0.1996,  ..., -0.7112,  0.0109, -0.2921],
        ...,
        [-0.1658, -0.3112, -0.1988,  ...,  0.1881, -0.0172, -0.2526],
        [-0.0832,  0.0362,  0.1085,  ...,  0.0222, -0.0466,  0.4242],
        [ 0.3856,  0.1533,  0.0356,  ..., -0.2136,  0.1036,  0.0482]])
Parameter conta

        0.0803, 0.1006, 0.1185, 0.0519, 0.1154, 0.1027, 0.0779, 0.0740])
Parameter containing:
tensor([[-0.3006,  0.9351,  0.0174,  ..., -0.1886, -0.1352, -0.7582],
        [ 0.3010,  0.5686, -0.2801,  ...,  0.8410, -0.3711, -0.7863],
        [-0.0406, -0.5383, -0.4342,  ..., -0.3133,  0.9302,  0.9275],
        ...,
        [ 0.6424,  0.7994,  1.6601,  ..., -0.9898, -0.1725, -0.7040],
        [-0.1827,  0.6042,  0.0510,  ...,  0.8619, -0.0599, -0.8155],
        [-0.6936, -0.5533, -0.7582,  ..., -0.5428,  0.4523, -0.2024]])
Parameter containing:
tensor([[-0.0364,  0.1612, -0.0824,  ..., -0.1891,  0.0045, -0.0237],
        [-0.2495, -0.0169, -0.1836,  ..., -0.6322, -0.2056, -0.3367],
        [-0.0112,  0.5351,  0.5637,  ..., -0.4046, -0.0693, -0.0625],
        ...,
        [-0.2081,  0.2617, -0.1680,  ...,  0.0990, -0.2670, -0.2722],
        [-0.5239, -0.0479, -0.3257,  ...,  0.3313,  0.1728, -0.2527],
        [ 0.2777,  0.3181,  0.0102,  ..., -0.2574, -0.0823,  0.1305]])
Parameter conta

         0.1365,  0.1377,  0.1152,  0.1454,  0.1618,  0.1194,  0.1107,  0.1518])
Parameter containing:
tensor([[-3.6774e-01,  6.0897e-02, -6.1379e-01,  ...,  3.0108e-01,
         -1.3294e-01, -1.0781e+00],
        [ 4.2299e-01,  7.9530e-01, -1.7122e-03,  ..., -1.4717e-01,
         -7.3636e-01, -8.0370e-02],
        [-9.6510e-02,  2.7976e-01,  6.6255e-01,  ...,  2.8287e-01,
         -7.4477e-01,  3.7566e-01],
        ...,
        [ 1.8479e+00,  2.1971e-01,  1.2158e-02,  ...,  1.1450e+00,
          2.6769e-01,  9.0475e-01],
        [-1.0712e+00, -1.0501e-01,  5.3570e-01,  ...,  1.2769e+00,
          9.2395e-01, -9.0564e-01],
        [ 1.7329e-01, -5.0812e-01, -4.3018e-01,  ...,  1.6609e-01,
          4.7244e-01,  9.4982e-01]])
Parameter containing:
tensor([[ 0.1386,  0.4502, -0.4279,  ...,  0.6739, -0.4308, -0.1188],
        [-0.1419,  0.4063,  0.1231,  ...,  0.5004, -0.2661,  0.0168],
        [ 0.2029,  0.3024,  0.1403,  ...,  0.2425, -0.0116,  0.0215],
        ...,
        [-0.2158,  0

         0.1005,  0.0837,  0.0752,  0.1156,  0.0743,  0.1107,  0.1109,  0.0873])
Parameter containing:
tensor([[-1.6355,  0.9350,  0.7482,  ...,  0.0348,  0.3593, -0.1283],
        [-0.3691,  0.5102,  0.0618,  ...,  0.1641,  0.2266,  0.0697],
        [-0.2640, -0.8063,  0.1428,  ...,  0.1459,  0.8482,  0.6710],
        ...,
        [-1.1699, -0.8048, -1.1567,  ...,  1.2032,  1.5433,  0.4758],
        [ 0.8253, -0.3066,  0.6623,  ...,  0.3784, -0.8496,  0.3449],
        [ 0.3455, -0.3918,  0.5515,  ...,  0.7512, -0.1180,  0.2321]])
Parameter containing:
tensor([[-0.8469, -0.6464, -0.0104,  ...,  0.4577, -0.1762, -0.1609],
        [ 0.1217, -0.1607,  0.1024,  ..., -0.3974,  0.0340, -0.7133],
        [ 0.2081,  0.0546, -0.0232,  ...,  0.3595, -0.2652, -0.5510],
        ...,
        [-0.2237,  0.1622, -0.2424,  ..., -0.2647, -0.4217, -1.1020],
        [ 0.0600, -0.2620, -0.1397,  ...,  1.3369, -0.2560,  0.7447],
        [ 0.1440,  0.3085, -0.1402,  ..., -0.1854,  0.0976,  0.3013]])
Paramet

         0.1156,  0.1209,  0.0664,  0.0990,  0.1273,  0.1172,  0.0914,  0.0716])
Parameter containing:
tensor([[-0.2487,  0.6063,  0.2425,  ..., -0.5356, -1.2714,  0.1041],
        [-0.3138, -0.2694, -0.1845,  ...,  0.0636,  0.1884, -0.2633],
        [ 1.2088, -0.3002, -0.9106,  ...,  0.6809, -0.5554,  0.2540],
        ...,
        [-1.0273,  0.3316, -0.0888,  ..., -0.8908,  0.8060,  0.1002],
        [ 0.5184, -0.6128,  0.5610,  ..., -0.3648, -0.9972, -0.2731],
        [-0.5785, -0.0938,  0.9467,  ..., -0.4072, -0.4913,  0.4899]])
Parameter containing:
tensor([[-0.0806, -0.4601,  0.8760,  ...,  0.0631,  0.6717,  0.4110],
        [-0.3665, -0.4081, -1.1533,  ..., -0.1140,  0.0588, -0.1392],
        [-0.0591, -0.2079, -0.1587,  ..., -0.1962, -0.5612,  0.5315],
        ...,
        [ 0.4630,  0.3184, -0.2325,  ..., -0.2862,  0.6793,  0.3932],
        [ 0.2801,  0.1800,  0.3295,  ..., -0.0657,  0.0063,  0.5601],
        [-0.4188,  0.3499, -0.5096,  ...,  0.2012, -0.4699, -0.4655]])
Paramet

         0.1087,  0.0820,  0.0948,  0.0873,  0.0983,  0.1122,  0.1050,  0.1294])
Parameter containing:
tensor([[ 0.1540,  0.5506,  0.5402,  ...,  0.3761, -0.0582,  0.3474],
        [ 0.3524,  0.8677,  0.5371,  ..., -1.3181, -0.7008,  0.9980],
        [ 0.4420,  0.6882, -0.6487,  ..., -0.5053,  0.2565,  0.9141],
        ...,
        [-0.4775,  0.1445, -0.0151,  ..., -0.2138,  0.7183,  0.0774],
        [ 0.3630, -1.0558,  1.2283,  ..., -1.3914,  1.5316,  1.0662],
        [ 0.7808,  0.2055, -0.0975,  ..., -0.8616, -0.2133, -0.2597]])
Parameter containing:
tensor([[-0.3852, -0.4624,  0.1450,  ...,  0.0268,  0.2928,  0.8180],
        [-0.3874,  0.0161,  0.0023,  ..., -0.2145, -1.0316,  0.7267],
        [ 0.0683,  0.3135,  0.3825,  ..., -0.1595, -0.4960,  0.3613],
        ...,
        [ 0.1512, -0.1703,  0.2230,  ..., -0.3944,  0.0562, -0.2528],
        [-0.1784,  0.1337, -0.0627,  ..., -0.4324,  0.4343, -0.3094],
        [ 0.2368, -1.3225, -0.2787,  ...,  0.1614,  0.1536, -0.7170]])
Paramet

        0.2336, 0.1439, 0.0950, 0.2896, 0.1848, 0.2499, 0.1695, 0.1440])
Parameter containing:
tensor([[-0.0334,  0.0786, -0.0425,  ...,  0.0271, -0.0703, -0.0053],
        [-0.0679,  0.0913,  0.0067,  ..., -0.1445,  0.0295,  0.0542],
        [ 0.0737,  0.1108, -0.0364,  ...,  0.0187,  0.0061, -0.0544],
        ...,
        [ 0.0067, -0.0020, -0.0136,  ...,  0.0216, -0.0055, -0.0054],
        [ 0.0452, -0.0059, -0.0121,  ...,  0.0142, -0.0172,  0.0884],
        [ 0.0140,  0.0747,  0.0247,  ..., -0.0232,  0.0510, -0.0415]],
       requires_grad=True)
Parameter containing:
tensor([[ 1.3047, -0.5469, -0.8750,  ..., -0.4375, -1.1953, -0.1855],
        [ 0.4043,  1.0312, -0.0255,  ..., -0.6992,  0.0432,  0.4492],
        [ 0.1875, -0.6680,  0.0991,  ...,  0.1299,  0.3047,  0.6406],
        ...,
        [ 0.5742,  1.0547,  0.0640,  ...,  0.3164, -0.1943,  0.3574],
        [ 0.4082,  0.5430,  0.3086,  ...,  0.3789,  0.4316, -0.2158],
        [-0.2324, -0.1787, -0.0938,  ...,  0.3691, -0.1943,

       requires_grad=True)
Parameter containing:
tensor([[-0.0591,  0.0427,  0.0588,  ..., -0.0369, -0.0288, -0.0474],
        [ 0.0684,  0.0106, -0.0261,  ...,  0.0317,  0.0540, -0.0276],
        [-0.1211,  0.0080, -0.0786,  ...,  0.1816,  0.0117, -0.0178],
        ...,
        [ 0.0747,  0.0123,  0.1113,  ..., -0.0225, -0.0264, -0.0107],
        [ 0.0635, -0.0374,  0.0461,  ...,  0.0588,  0.0327, -0.0461],
        [-0.0325, -0.0688, -0.0557,  ...,  0.0288, -0.0417,  0.0140]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.0840, -0.0062, -0.0547,  ..., -0.2275, -0.3301,  0.1289],
        [ 0.7344,  0.0840, -0.1641,  ...,  0.1592, -0.1768, -0.1348],
        [-0.8242,  0.0152, -0.3457,  ...,  0.1621, -0.4336,  0.0684],
        ...,
        [-0.1040,  0.2471, -0.3164,  ..., -0.3340, -0.2832,  0.6250],
        [-0.1504,  0.0092,  0.6133,  ..., -0.6992,  0.5508,  0.5039],
        [-0.1934,  0.5742,  0.5469,  ..., -0.1445, -0.3984, -0.8242]],
       requires_grad=True)
Paramet

       requires_grad=True)
Parameter containing:
tensor([[-0.0243,  0.0659, -0.0142,  ...,  0.0106,  0.0884, -0.0036],
        [-0.0095, -0.0098, -0.0559,  ..., -0.0225,  0.0518, -0.1055],
        [ 0.0315, -0.0791,  0.0654,  ...,  0.0471,  0.0240, -0.0140],
        ...,
        [-0.0315,  0.0505, -0.0311,  ..., -0.0120,  0.0615,  0.0143],
        [-0.0723,  0.1211, -0.0232,  ..., -0.0006,  0.0016,  0.0474],
        [ 0.0240,  0.0131, -0.0118,  ...,  0.0150,  0.0459,  0.0442]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.5547, -0.3164,  0.3027,  ...,  0.0444,  0.2012, -0.5898],
        [ 0.8906, -0.2246, -0.0270,  ...,  0.5977, -0.3281,  0.2812],
        [ 0.3164, -0.5859, -0.7344,  ...,  0.0908, -0.4707, -0.4062],
        ...,
        [ 0.0540,  0.0796,  0.2275,  ...,  0.0272,  0.0172,  0.1367],
        [ 0.2891, -0.1729,  0.0206,  ...,  0.5781,  0.1309,  0.1436],
        [-0.0801,  0.4785,  0.1533,  ..., -0.0322,  0.2061, -0.2119]],
       requires_grad=True)
Paramet

       requires_grad=True)
Parameter containing:
tensor([ 0.0811,  0.1846,  0.0737,  0.1138,  0.0605,  0.0952,  0.1152,  0.0664,
         0.0649,  0.1187,  0.1016,  0.0403,  0.0957,  0.1016,  0.0786,  0.0835,
         0.1016,  0.0820,  0.1030,  0.0879,  0.0806,  0.0894,  0.1108,  0.0330,
         0.1094,  0.0698,  0.1133,  0.0527,  0.1030,  0.1055,  0.0908,  0.0781,
         0.1133,  0.0972,  0.0894,  0.0693,  0.0981,  0.1011,  0.1064,  0.0757,
         0.0957,  0.0967,  0.0530,  0.1206,  0.0957,  0.0845,  0.0286,  0.0781,
         0.0977,  0.0952,  0.1030,  0.1143,  0.0894,  0.0527,  0.0781,  0.1094,
         0.0962,  0.1147,  0.0981,  0.0898,  0.1211,  0.0840,  0.1328,  0.1099,
         0.1299,  0.1035,  0.0928,  0.1113,  0.0576,  0.1152,  0.1318,  0.0830,
         0.1030,  0.0933,  0.0908,  0.0977,  0.1064,  0.0732,  0.0962,  0.0718,
         0.0869,  0.1084,  0.0933,  0.1138,  0.0879,  0.0815,  0.0854,  0.0664,
         0.1069,  0.1621,  0.0947,  0.0762,  0.0791,  0.1162,  0.1001, 

       requires_grad=True)
Parameter containing:
tensor([[-3.9795e-02,  8.7280e-03,  7.3242e-03,  ..., -8.9722e-03,
         -5.6763e-03, -3.8086e-02],
        [-6.0547e-02, -1.9653e-02,  7.8125e-03,  ...,  7.8735e-03,
         -7.2266e-02,  2.3315e-02],
        [-3.2959e-02, -1.2329e-02, -5.4199e-02,  ...,  3.1250e-02,
          8.6060e-03, -5.6396e-02],
        ...,
        [ 1.6309e-01, -3.9062e-02,  1.2878e-02,  ...,  2.6367e-02,
         -6.8848e-02,  6.2988e-02],
        [ 9.1797e-02,  5.0293e-02, -1.7700e-02,  ...,  5.3955e-02,
         -1.7881e-05, -2.7832e-02],
        [ 2.8198e-02,  3.1494e-02,  1.1816e-01,  ..., -5.6885e-02,
         -3.8330e-02,  2.7954e-02]], requires_grad=True)
Parameter containing:
tensor([[ 8.8672e-01,  2.5195e-01, -2.5391e-01,  ...,  1.2256e-01,
         -1.3770e-01,  1.3477e-01],
        [-3.7354e-02, -1.3965e-01, -4.4141e-01,  ..., -5.4297e-01,
         -8.5547e-01,  1.1406e+00],
        [ 1.4973e-04,  1.1414e-02, -3.0664e-01,  ...,  3.4180e-01,
    

          1.5527e-01, -3.5156e-02]], requires_grad=True)
Parameter containing:
tensor([[-0.4199,  0.9492,  0.5469,  ..., -1.0391,  0.3672, -0.1084],
        [-0.3887, -1.1797, -1.3359,  ..., -1.1875,  0.5547,  0.6133],
        [ 1.7266, -1.1562,  0.7383,  ...,  0.4180, -1.9297, -0.7969],
        ...,
        [-0.5312,  0.1924, -0.2021,  ..., -0.7305, -1.2734,  0.4551],
        [ 0.7578, -1.3594,  0.2598,  ...,  0.5352, -0.1611,  1.5625],
        [ 0.9688,  0.7344, -0.8281,  ...,  0.7227, -0.7227,  0.4414]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.1699, -1.6328, -0.5273,  ..., -0.2090, -1.3828, -0.0266],
        [ 0.5820, -0.6562, -2.0469,  ...,  0.8164,  0.1270, -0.3457],
        [-0.2500, -0.6094,  0.4844,  ...,  0.9336, -0.9883,  0.3457],
        ...,
        [-0.2852,  1.0703,  0.3086,  ..., -0.2734,  0.8164, -0.1738],
        [-0.4082, -0.2812, -0.3203,  ..., -0.2734, -0.1562, -0.1982],
        [ 0.3359,  0.3848,  0.1533,  ..., -1.0078, -0.5508,  0.7305]],
    

       requires_grad=True)
Parameter containing:
tensor([[ 0.0223,  0.0608,  0.0003,  ...,  0.0398,  0.0342, -0.0016],
        [ 0.0649,  0.0972, -0.0131,  ..., -0.0061,  0.0254,  0.0559],
        [ 0.0253,  0.0425,  0.0349,  ...,  0.0898,  0.1221, -0.0164],
        ...,
        [ 0.0378, -0.0596,  0.0110,  ..., -0.0233, -0.0046, -0.0923],
        [ 0.0457,  0.0449,  0.0029,  ..., -0.0243,  0.0562,  0.0674],
        [ 0.0162, -0.0269, -0.0640,  ..., -0.0452,  0.0376,  0.0908]],
       requires_grad=True)
Parameter containing:
tensor([[-0.4570,  0.8867,  0.2070,  ...,  0.0178,  0.9297,  0.7891],
        [ 0.6211, -0.1104,  0.1836,  ...,  0.2734,  0.3848,  0.2559],
        [-0.3633,  0.4883,  0.6406,  ..., -0.9688,  0.8516, -0.3750],
        ...,
        [-0.4902, -0.1260,  0.5625,  ..., -0.3555,  0.7227,  0.0947],
        [-0.6797, -0.2832, -0.3184,  ...,  0.0869, -0.4961,  0.1147],
        [ 0.1650,  0.3027,  0.4434,  ..., -1.1797, -0.0474,  0.0330]],
       requires_grad=True)
Paramet

       requires_grad=True)
Parameter containing:
tensor([[-0.6367,  0.5625,  0.0347,  ...,  0.6953, -0.7461,  0.2969],
        [ 0.7617,  1.0625, -0.3438,  ..., -1.1719,  1.5625,  1.3203],
        [-0.3027, -0.9961,  0.0045,  ...,  0.6055, -1.5469,  0.1367],
        ...,
        [-1.0859,  0.0698,  2.5938,  ..., -0.0947,  0.3203, -0.0811],
        [ 2.7344,  0.8516, -0.7500,  ..., -0.4766, -1.6016, -0.8359],
        [-0.1836, -0.0879,  2.1875,  ..., -1.1328,  0.2754, -0.1748]],
       requires_grad=True)
Parameter containing:
tensor([[ 4.2188e-01,  3.0664e-01,  1.0742e-01,  ..., -8.7891e-01,
          2.2031e+00, -1.1169e-02],
        [ 3.7305e-01,  1.1172e+00,  8.3594e-01,  ...,  2.1406e+00,
          8.4375e-01,  3.3984e-01],
        [ 2.2363e-01, -6.6016e-01,  1.0000e+00,  ...,  6.6797e-01,
         -6.2891e-01,  1.8984e+00],
        ...,
        [-3.6523e-01, -9.8047e-01,  3.0469e-01,  ...,  4.6875e-01,
          1.5391e+00, -1.3672e-01],
        [ 6.6406e-01,  1.4114e-03, -7.2266e

       requires_grad=True)
Parameter containing:
tensor([[ 0.0742, -0.0068, -0.0020,  ...,  0.0723,  0.0167,  0.0109],
        [ 0.0138,  0.0017,  0.0058,  ...,  0.0718, -0.1357, -0.0369],
        [-0.0781, -0.0447,  0.0253,  ..., -0.0024,  0.0684, -0.0415],
        ...,
        [ 0.0364, -0.0840, -0.0179,  ..., -0.0549,  0.0757, -0.0086],
        [-0.0776,  0.0116, -0.0005,  ...,  0.0053,  0.0391,  0.0161],
        [ 0.0219, -0.0393, -0.0282,  ...,  0.0457, -0.0275,  0.0566]],
       requires_grad=True)
Parameter containing:
tensor([[-0.2197, -0.0508, -0.0781,  ..., -0.1035,  0.2119, -0.2344],
        [ 1.1406,  0.0923,  0.1235,  ..., -0.6406, -0.5977,  0.1201],
        [ 0.2256,  0.2832,  0.3203,  ..., -0.3223, -0.4023,  0.2734],
        ...,
        [ 0.2656,  0.3984, -0.0320,  ..., -0.1514,  0.3301,  0.4941],
        [-0.3809, -0.2480,  1.0312,  ..., -0.0344, -0.2637, -0.2393],
        [ 0.1396,  0.0513,  0.3789,  ..., -1.1719, -0.0554, -0.1973]],
       requires_grad=True)
Paramet

       requires_grad=True)
Parameter containing:
tensor([[-0.1777,  1.4141,  0.1826,  ..., -0.0104,  1.1250,  0.4629],
        [ 0.8438,  0.8867, -0.2373,  ..., -1.2031, -1.0938, -0.3652],
        [-0.8438, -0.0188,  0.1592,  ..., -1.9531,  0.0835,  1.2969],
        ...,
        [ 2.1094, -2.2344, -0.9727,  ...,  1.8047,  0.4043, -1.1562],
        [ 0.1069,  2.2344,  0.9297,  ..., -0.6016,  0.4609, -1.3047],
        [-0.9297,  0.1069, -0.9102,  ..., -0.5742,  0.5000, -0.4688]],
       requires_grad=True)
Parameter containing:
tensor([ 0.0645,  0.0859,  0.0610,  0.0806,  0.0532,  0.0757,  0.1016,  0.0688,
         0.0586,  0.0583,  0.0796,  0.0547,  0.0645,  0.0703,  0.0574,  0.0625,
         0.0762,  0.0659,  0.0718,  0.0806,  0.0618,  0.0562,  0.0679,  0.0430,
         0.0786,  0.0674,  0.0664,  0.0767,  0.0713,  0.0913,  0.0630, -0.0564,
         0.0913,  0.0801,  0.1040,  0.0635,  0.0640,  0.0645,  0.0791,  0.0635,
         0.0588,  0.0635,  0.0679,  0.0811,  0.0640,  0.0581,  0.034

       requires_grad=True)
Parameter containing:
tensor([[-0.0366, -0.0088, -0.0591,  ..., -0.0659,  0.0630, -0.0315],
        [ 0.0344,  0.0339, -0.0007,  ...,  0.0688, -0.0698,  0.0349],
        [-0.0530,  0.0162, -0.0216,  ...,  0.0059,  0.0027,  0.0063],
        ...,
        [ 0.0034,  0.0034, -0.0012,  ..., -0.0036, -0.0179, -0.0153],
        [ 0.0200,  0.0640,  0.0193,  ...,  0.0119,  0.0747, -0.0064],
        [-0.0197, -0.0309,  0.0437,  ..., -0.0488, -0.0277,  0.0171]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.6445, -0.0413, -0.9297,  ..., -0.2578,  0.4941,  0.8125],
        [ 0.5430, -1.0391, -0.7812,  ..., -0.2100, -0.1924,  0.7422],
        [-0.2256,  0.3418,  0.5352,  ..., -0.6914, -0.1719,  0.2188],
        ...,
        [-0.4258, -0.2061, -0.2061,  ...,  0.0830, -0.7617, -0.1895],
        [-0.3223, -0.1943,  0.0835,  ...,  0.0089, -0.9023,  0.3223],
        [ 0.1426, -0.7266, -0.0588,  ...,  0.3281, -0.0091, -0.2402]],
       requires_grad=True)
Paramet

       requires_grad=True)
Parameter containing:
tensor([[ 2.4688,  0.2871, -0.2354,  ..., -2.1562,  3.0312,  0.2041],
        [-1.9766,  0.6211,  0.4238,  ...,  1.4766,  0.4707, -0.6172],
        [ 1.4219,  0.7695, -0.8711,  ..., -1.1641,  2.6094, -1.7969],
        ...,
        [-0.9062, -1.0547, -1.8203,  ...,  0.1182, -1.6406, -1.6094],
        [-0.6133,  0.2061, -0.4434,  ..., -0.6992, -0.6250, -0.1543],
        [ 0.4160, -0.3906,  0.2041,  ...,  0.1973,  0.3203, -0.1943]],
       requires_grad=True)
Parameter containing:
tensor([ 0.1592,  0.1152,  0.1270,  0.1943,  0.0396,  0.1553,  0.2012,  0.1221,
         0.1006,  0.1226,  0.1475,  0.1436,  0.1357,  0.1504, -0.0747,  0.1113,
         0.1514,  0.1641,  0.1523,  0.1475,  0.1387,  0.1060,  0.1245,  0.0408,
         0.1514,  0.1543,  0.1260,  0.1167,  0.1396,  0.1895,  0.1289,  0.1104,
         0.1924,  0.1865,  0.1855,  0.1094,  0.1309,  0.1309,  0.2061,  0.1226,
         0.1426,  0.1133,  0.1875,  0.1748,  0.1436,  0.1201,  0.043

       requires_grad=True)
Parameter containing:
tensor([ 1.4708e-01,  3.4447e-01,  1.4789e-01,  1.2291e-01, -1.1636e-03,
         1.0635e-01,  2.9707e-01,  2.1828e-01,  5.5557e-01,  1.1080e-01,
         1.1728e-01,  3.7152e-01,  1.1633e-01,  9.9032e-02,  1.7351e+00,
         1.2763e-01,  1.0339e-01,  1.6823e-01,  1.3493e-01,  1.2870e-01,
         1.4680e-01,  9.4393e-02,  1.5034e-01,  1.1528e+00,  1.0193e-01,
         1.8226e-01,  1.0032e-01,  5.7228e-01,  1.1534e-01,  2.0333e-01,
         1.1902e-01,  2.1540e-01,  1.6978e-01,  1.2572e-01,  2.8245e-01,
         2.0308e-01,  1.3183e-01,  1.1769e-01,  1.5051e-01,  1.5472e-01,
         1.2448e-01,  9.9111e-02,  2.3669e-01,  1.2650e-01,  1.6262e-01,
         1.3315e-01,  2.3197e-01,  1.5795e-01,  1.7246e-01,  1.1895e-01,
         1.4363e-01,  1.4300e-01,  2.0426e-01,  4.6827e-01,  3.4454e-01,
         1.3522e-01,  1.4810e-01,  1.5063e-01,  1.2105e-01,  2.6712e-01,
         8.9629e-02,  2.4379e-01,  2.2667e-01,  1.0368e-01,  1.1707e-01,
  

        [-0.1107, -0.0712, -0.1420,  ..., -0.1076, -0.0570, -0.0725]])
Parameter containing:
tensor([[ 0.1419, -0.0947, -0.1654,  ..., -0.1171, -0.1268, -0.1406],
        [ 0.1231, -0.0094, -0.0507,  ...,  0.0768,  0.1110,  0.0532],
        [-0.0250,  0.1136, -0.0637,  ...,  0.0536,  0.0568, -0.1511],
        ...,
        [-0.1523, -0.1015,  0.0489,  ...,  0.0984, -0.0475, -0.1243],
        [ 0.1588, -0.0793, -0.1605,  ..., -0.0428,  0.1615,  0.0154],
        [-0.1285,  0.1081,  0.0423,  ...,  0.1267,  0.0056, -0.0586]],
       requires_grad=True)


In [47]:
fixed_strings = []

        # test_dataloader = DataLoader(test_set, batch_size=args.batch_size)
#test_dataloader = DataLoader(train_set, batch_size=args.batch_size)
for i in range(2):

            # category = category.to(model.master_ctx)
            # tokenized = tokenized.to(model.master_ctx)
            # mask = mask.to(model.master_ctx)

            # model.train()
            # x, z, mu, logvar = model(condition, tokenized, mask, label)
            # loss = x - 1
            # loss.mean().backward()
            # for name, param in model.named_parameters():
            #    if param.grad is None:
            #        print(name)

            # continue
    with torch.no_grad():
        sampled_z = torch.rand((1, 32))
        print(sampled_z)
        fixed_tokens = generate(
                    model,
                    starter_tokens=[model.config.decoder_start_token_id],
                    #input_ids=tokenized,
                    #attention_mask=mask,
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    top_p=1,
                    top_k=1,
                    temperature=1.0,
                    output_hidden_states=True,
                    num_beams=20,
                    use_cache=True,
                    sampled_z = sampled_z
                )

    fixed = tokenizer.batch_decode(fixed_tokens, skip_special_tokens=True)
    #original = tokenizer.batch_decode(tokenized, skip_special_tokens=True)

    for f in fixed:

                # print(f"--------\n[CONDITION] {condition}\n[ORIGINAL] {o}\n[FIXED] {f}")
        print(f"--------\n[FIXED] {f}")

tensor([[0.1007, 0.4774, 0.0052, 0.1121, 0.4071, 0.2384, 0.1197, 0.3160, 0.0862,
         0.2766, 0.6118, 0.1288, 0.1038, 0.3147, 0.9559, 0.4759, 0.5058, 0.3884,
         0.7934, 0.5538, 0.1048, 0.6467, 0.8006, 0.6423, 0.3725, 0.6044, 0.8789,
         0.8351, 0.5709, 0.6088, 0.5608, 0.3052]])
--------
[FIXED] a a slew of a slew of a slew of a a a a a a a a a a a a a a a a a slew of a slew of a slew of.......
tensor([[0.8928, 0.3789, 0.4691, 0.5059, 0.6182, 0.5168, 0.6549, 0.4933, 0.3670,
         0.5736, 0.4139, 0.4368, 0.7602, 0.3926, 0.4630, 0.5737, 0.6733, 0.0875,
         0.2513, 0.5212, 0.7130, 0.7029, 0.6795, 0.8861, 0.0129, 0.5521, 0.2889,
         0.4131, 0.9397, 0.9472, 0.9823, 0.3961]])
--------
[FIXED] a a slew of a slew of a slew of a a a a a a a a a a a a a a a a a slew of a slew of a slew of...... 


In [21]:
dataset = datasets.get("snli")
if not dataset:
    raise Exception("Wrong dataset.")

dataset_class = dataset["dataset_class"]
out_dim = 512
train_set = dataset_class(dataset["train_file"], tokenizer, out_dim)
valid_set = dataset_class(dataset["validate_file"], tokenizer, out_dim)

In [22]:
train_dataloader = DataLoader(
        train_set, batch_size=32, num_workers=1, shuffle=True
    )

In [23]:
train_indices = list(range(len(train_set)))
np.random.shuffle(train_indices)

In [24]:
train_dataloader = DataLoader(
        train_set, batch_size=32, num_workers=1, sampler=SubsetRandomSampler(train_indices[:320]),pin_memory=True
    )
valid_dataloader = DataLoader(
        valid_set, batch_size=32, num_workers=1    )

In [50]:
calc_mi(model, valid_dataloader)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
-46.829833984375
-86.28955078125
-91.126953125
-103.961669921875
-93.4228515625
-80.132568359375
-64.1295166015625
-32.8935546875
-14.37744140625
44.2152099609375


0.13817253112792968

In [16]:
def calc_mi(model, test_data_batch):
    mi = 0
    num_examples = 0
    for batch_data in train_dataloader:
        batch_data = batch_data[0]
        batch_size = batch_data.size(0)
        num_examples += batch_size
        mutual_info = calc_mi_q(batch_data)
        mi += mutual_info * batch_size
        print(mi)

    return mi / num_examples

In [17]:
def calc_mi_q(x):
        """Approximate the mutual information between x and z
        I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
        Returns: Float
        """

        # [x_batch, nz]
        encoder_outputs = model.t5.run_encoder(input_ids=x,output_hidden_states=True)
        pooled = model.t5.pool(encoder_outputs.hidden_states)
        z, mu, logvar = model.t5.calculate_latent(pooled)
        z = z.unsqueeze(1)
        x_batch, nz = mu.size()

        # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).mean()


        # [1, x_batch, nz]
        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        
        dev = z - mu

        
        # (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior
        # [z_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()

In [18]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        return m + torch.log(sum_exp)

In [49]:
model.t5.pooling_strategy = "max"

In [16]:
for batch_data in next(iter(train_dataloader)):
    print(len(batch_data))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
32
32
32
