In [3]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [4]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

In [15]:
from constants import *

In [5]:
import matplotlib.pyplot as plt

def show_adn_save(file_name, img):
    """Show an image tensor inline and also write it to ``./<file_name>.png``.

    ``img`` is converted with ``.numpy()`` and its axes are permuted with
    ``(1, 2, 0)`` before being handed to matplotlib (presumably turning a
    channel-first image into a channel-last one -- TODO confirm with callers).
    """
    pixels = img.numpy().transpose(1, 2, 0)
    out_path = "./%s.png" % file_name
    plt.imshow(pixels)
    plt.imsave(out_path, pixels)

In [61]:
class RBM(nn.Module):
    """Bernoulli restricted Boltzmann machine trained with contrastive divergence.

    Parameters
    ----------
    n_vis : int
        Number of visible units.
    n_hin : int
        Number of hidden units.
    k : int
        Number of Gibbs steps run by ``forward``; must be at least 1.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1 (``forward`` would have no sample to return).
    """

    def __init__(self,
                 n_vis=784,
                 n_hin=500,
                 k=5):
        super(RBM, self).__init__()
        if k < 1:
            raise ValueError("k must be >= 1, got %r" % (k,))
        # Small random weights; both bias vectors start at zero.
        self.W = nn.Parameter(torch.randn(n_hin,n_vis)*1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k

    def sample_from_p(self,p):
        """Return a 0/1 tensor shaped like ``p``; an entry is 1 where ``p`` exceeds uniform noise."""
        # rand_like keeps the noise on the same device and dtype as ``p``.
        return (p > torch.rand_like(p)).to(p.dtype)

    def v_to_h(self,v):
        """Return ``(p_h, sample_h)``: hidden activation probabilities and a sample from them."""
        p_h = torch.sigmoid(F.linear(v,self.W,self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h,sample_h

    def h_to_v(self,h):
        """Return ``(p_v, sample_v)``: visible activation probabilities and a sample from them."""
        p_v = torch.sigmoid(F.linear(h,self.W.t(),self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v,sample_v

    def forward(self,v):
        """Run ``k`` Gibbs steps starting from ``v``.

        Returns ``(v, v_)``: the unchanged input and the visible sample
        produced by the last step.
        """
        # Sampling is not differentiable, so no autograd graph is needed for
        # the chain; gradients reach the parameters only via ``free_energy``.
        with torch.no_grad():
            _, h_ = self.v_to_h(v)
            for _ in range(self.k):
                _, v_ = self.h_to_v(h_)
                _, h_ = self.v_to_h(v_)

        return v,v_

    def free_energy(self,v):
        """Return the mean free energy of the rows of the 2-D batch ``v``."""
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v,self.W,self.h_bias)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        hidden_term = F.softplus(wx_b).sum(1)
        return (-hidden_term - vbias_term).mean()

### Old way of pulling out corpus

In [16]:
# Read the whole concatenated corpus as text. The encoding is passed explicitly
# so decoding does not depend on the platform's default locale encoding.
# NOTE(review): UTF-8 is assumed from the `.utf` suffix -- confirm.
with open(f'{OUT_DIR}/concat_corpus.utf', encoding='utf-8') as f:
    train_contents = f.read()

In [147]:
train_contents[0:10]

'їQ\x11v!ÆÈæ\x80l'

In [148]:
train_contents[10]

'Æ'

### New hotness - one hot encoded vectors

In [79]:
import h5py    
import numpy as np    
import json
# The HDF5 handle is kept open because later cells index into its datasets.
concat_h5 = h5py.File(f'{OUT_DIR}/concat_corpus.h5','r+') 

# Load the vocabulary metadata; the context manager closes the file handle
# that a bare `json.load(open(...))` would leave open.
with open(f'{OUT_DIR}/concat_corpus.json', 'rb') as json_fh:
    concat_json = json.load(json_fh)

In [276]:
concat_json['idx_to_token']

{'1': 'ї',
 '10': 'l',
 '100': 'Þ',
 '101': 'L',
 '102': '\x08',
 '103': '\x16',
 '104': '\x98',
 '105': '÷',
 '106': 'ā',
 '107': 'å',
 '108': 'ë',
 '11': 'z',
 '12': '¾',
 '13': 'g',
 '14': 'K',
 '15': '\x84',
 '16': '\x91',
 '17': 'y',
 '18': '?',
 '19': '\x1e',
 '2': 'Q',
 '20': '¹',
 '21': '®',
 '22': '¼',
 '23': 'ú',
 '24': 'c',
 '25': 'ì',
 '26': '\x19',
 '27': '\x0f',
 '28': '\x93',
 '29': 'º',
 '3': '\x11',
 '30': '\x1f',
 '31': '\x03',
 '32': 'B',
 '33': '8',
 '34': '5',
 '35': '"',
 '36': 'e',
 '37': 'Ú',
 '38': '*',
 '39': '\x8d',
 '4': 'v',
 '40': '\x0c',
 '41': 'Ø',
 '42': '£',
 '43': '\x0b',
 '44': '\x8c',
 '45': '²',
 '46': '\x8a',
 '47': 'Z',
 '48': 'ï',
 '49': '\x1a',
 '5': '!',
 '50': 'n',
 '51': 'a',
 '52': '\x7f',
 '53': '\x94',
 '54': '\\',
 '55': 'D',
 '56': 'ó',
 '57': 'ù',
 '58': 'Â',
 '59': 'b',
 '6': 'Æ',
 '60': 'ý',
 '61': 'P',
 '62': 'J',
 '63': 'ß',
 '64': 'ћ',
 '65': 'j',
 '66': 'R',
 '67': '\x05',
 '68': '\x8e',
 '69': 'Ç',
 '7': 'È',
 '70': '\x9c',
 '71

In [92]:
vocab_size = len(concat_json['idx_to_token'])

In [80]:
list(concat_h5.keys())
train = concat_h5['train']

In [138]:
train.shape[0]

273365

In [None]:
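# Sliding-window example for the token sequence 1, 2, 3, 4, 5 with seq = 3:
# each window of 3 tokens is an input and the token that follows is its target.
#   1, 2, 3 -> 4
#   2, 3, 4 -> 5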

In [146]:
len(c_in_dat[0])

8

In [None]:
batch_size = 64
def music_hot_generator(sequence_length, batch_size):
    """Yield ``(start, (xs_batch, y_batch))`` for consecutive slices of ``xs`` and ``y``.

    ``start`` runs from 0 up to ``xs.shape[0]`` in steps of ``batch_size``;
    each batch is ``xs[start:start+batch_size]`` paired with the same slice
    of ``y``.

    NOTE(review): ``xs`` and ``y`` are neither parameters nor locals, so they
    must exist as globals when the generator is iterated, and
    ``sequence_length`` is not used in the body.
    """
    for i in range(0, xs.shape[0], batch_size):
        yield i, (xs[i:i+batch_size], y[i:i+batch_size])
        

### One hot encoding

In [88]:
def one_hot(a, c):
    """Return one-hot rows for the indices in ``a`` using ``c`` classes.

    Every index picks the matching row of a ``c x c`` identity matrix, so the
    result has the shape of ``a`` plus a trailing axis of length ``c``.
    """
    identity = np.eye(c)
    return identity[a]

In [115]:
x_hot = one_hot(train, vocab_size+1)

In [133]:
# split_xy takes (sequence, sequence_length, batch_size); the extra fourth
# argument (vocab_size) made this call raise a TypeError.
testi = split_xy(train, 8, 64)

In [135]:
def split_xy(sequence, sequence_length, batch_size):
    """Yield batches of sliding windows over ``sequence`` with their next-token targets.

    Window ``j`` is ``sequence[j : j + sequence_length]`` and its target is
    ``sequence[j + sequence_length]``.

    Parameters
    ----------
    sequence : array-like
        1-D token sequence; it is converted to a NumPy array once up front.
    sequence_length : int
        Number of tokens in each input window.
    batch_size : int
        Number of windows per yielded batch (the last batch may be shorter).

    Yields
    ------
    (start, (xs_batch, y_batch))
        ``start`` is the index of the first window in the batch, ``xs_batch``
        has shape ``(batch, sequence_length)`` and ``y_batch`` has shape
        ``(batch,)``. Nothing is yielded when the sequence is too short to
        form a single window plus target.
    """
    data = np.asarray(sequence)
    n_windows = len(data) - sequence_length
    if n_windows <= 0:
        return

    # window_idx[j, i] == j + i, so one fancy-index builds every window at once
    # instead of indexing the sequence element by element in Python.
    window_idx = np.arange(n_windows)[:, None] + np.arange(sequence_length)[None, :]
    xs = data[window_idx]
    y = data[sequence_length:]

    for i in range(0, n_windows, batch_size):
        yield i, (xs[i:i+batch_size], y[i:i+batch_size])

In [205]:
class MusicDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over a token sequence stored in an HDF5 file.

    Item ``idx`` is ``(x_hot, y)``: ``x_hot`` is the one-hot encoding of the
    ``sequence_length`` tokens starting at ``idx`` and ``y`` is the token
    that immediately follows that window.
    """

    def __init__(self, h5_file, set_type, json_file, sequence_length, root_dir):
        """Open the corpus and its vocabulary metadata.

        Parameters
        ----------
        h5_file : str
            HDF5 file name inside ``root_dir``.
        set_type : str
            Name of the dataset inside the HDF5 file (e.g. ``'train'``).
        json_file : str
            JSON file name inside ``root_dir``; must contain ``idx_to_token``.
        sequence_length : int
            Number of tokens per input window.
        root_dir : str
            Directory that holds both files.
        """
        # NOTE(review): only reads are performed on this handle in the class;
        # the 'r+' mode is kept as-is in case other code relies on it.
        self.concat_h5 = h5py.File(f'{root_dir}/{h5_file}','r+')
        self.dataset = self.concat_h5[set_type]
        # Context manager so the JSON file handle is closed after loading.
        with open(f'{root_dir}/{json_file}', 'rb') as json_fh:
            self.concat_json = json.load(json_fh)
        # One slot more than the number of vocabulary entries
        # (presumably because token ids start at 1 -- TODO confirm).
        self.vocab_size = len(self.concat_json['idx_to_token'])+1
        self.data_length = self.dataset.shape[0]
        self.sequence_length = sequence_length

    def __len__(self):
        """Return the number of items exposed: one tenth of the possible windows."""
        # NOTE(review): the // 10 restricts indexing to the first tenth of the
        # window start positions; looks like a deliberate way to shorten epochs.
        return (self.data_length - self.sequence_length)//10

    def __getitem__(self, idx):
        """Return ``(x_hot, y)`` for the window that starts at ``idx``."""
        # Use the instance's window length; the previous code read a global
        # ``sequence_length`` and ignored the constructor argument.
        end = idx + self.sequence_length
        x = self.dataset[idx:end]
        y = self.dataset[end]

        x_hot = one_hot(x, self.vocab_size)
        return x_hot, y


In [206]:
md = MusicDataset(h5_file='concat_corpus.h5', set_type='train', json_file='concat_corpus.json', sequence_length=8, root_dir=OUT_DIR)

In [207]:
train_loader = torch.utils.data.DataLoader(md,
    batch_size=batch_size)

### Dataset sanity test

In [208]:
train_iter = enumerate(train_loader)

In [240]:
i, (x, y) = next(train_iter)
i2, (x2, y2) = next(train_iter)

In [241]:
md.dataset[:100]

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  6,  6,  6, 11,  3, 12, 13,
        6, 14,  8, 15, 16,  6, 14,  8,  4, 16,  6, 14,  8,  9, 16,  6,  2,
       17, 18, 19,  6,  7, 20, 21, 22,  6, 23, 20, 21, 22,  6, 24, 20, 21,
       22,  6, 25,  3, 26, 27,  6, 28,  8, 29, 30,  6, 23,  8, 29, 30,  6,
       24,  8, 29, 30,  6,  2, 17, 18, 31,  6,  7, 20, 21, 32,  6,  7, 20,
        4,  5,  6,  7, 20,  9, 10,  6, 23, 26, 12, 18,  6, 24, 29],
      dtype=uint8)

In [242]:
# Neighbouring samples inside one batch should be windows shifted by one token:
# sample 11 without its last token must equal sample 10 without its first token.
a = np.argmax(x[11], axis=1)[:-1]
b = np.argmax(x[10], axis=1)[1:]
np.testing.assert_array_equal(a, b)

In [244]:
# The same one-token shift should hold across the batch boundary: the first
# sample of the second batch (x2) is compared with the last sample of the
# first batch (x).
a = np.argmax(x2[0], axis=1)[:-1]
b = np.argmax(x[-1], axis=1)[1:]
np.testing.assert_array_equal(a, b)

In [249]:
x.shape

torch.Size([64, 8, 109])

In [250]:
t_shape = x.shape

### Training

In [245]:
rbm = RBM(n_vis=md.vocab_size*md.sequence_length, k=1)

In [246]:
train_op = optim.SGD(rbm.parameters(),0.1)

In [251]:
# Contrastive-divergence training: the loss is the free-energy gap between the
# data batch and the sample the RBM reconstructs from it.
for epoch in range(10):
    loss_ = []
    for i, (data, _) in enumerate(train_loader):
        # A short batch ends the epoch early.
        if data.shape[0] != train_loader.batch_size:
            print('Mismatched shapes:', i, data.shape)
            break
        # Flatten each (sequence_length, vocab_size) window into one visible vector.
        data = data.float().view(data.shape[0], -1)
        sample_data = data.bernoulli()

        v,v1 = rbm(sample_data)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        # .item() reads the scalar; indexing a 0-dim tensor with [0] raises
        # on current PyTorch.
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()

        if (i % 100 == 0):
            print(f'i: {i}, loss: {np.mean(loss_)}')

    print(np.mean(loss_))

27335
-1.12091064453125
-0.5219744691754332
-0.9445724866876555
-0.9403748369692172
-0.4558586444046135
Mismatched shapes: 427 torch.Size([7, 8, 109])
-0.5188454978639124
27335
5.022308349609375
0.28336930983137376
-0.21261384119441853
-0.2402448178921823
0.17332135055427836
Mismatched shapes: 427 torch.Size([7, 8, 109])
0.09997158363217214
27335
3.933074951171875
0.40506011660736385
-0.1473179148204291
-0.17393843438538206
0.19256911432356608
Mismatched shapes: 427 torch.Size([7, 8, 109])
0.10507273618175497
27335
0.907562255859375
0.37040181679300743
-0.25541208751166045
-0.29832676478794645
-0.004735247452657419
Mismatched shapes: 427 torch.Size([7, 8, 109])
-0.10526877767308256
27335
2.853607177734375
-0.0430470079478651
-0.5580532396610697
-0.6672031744770037
-0.3685808823887547
Mismatched shapes: 427 torch.Size([7, 8, 109])
-0.4177213217670521
27335
1.289703369140625
-0.12000705227993502
-0.721827340956351
-0.829992604810138
-0.5182282407384858
Mismatched shapes: 427 torch.Size([

KeyboardInterrupt: 

In [255]:
v.shape

torch.Size([64, 872])

In [256]:
v1.shape

torch.Size([64, 872])

In [263]:
# NOTE(review): `type` is referenced without being called, so this cell shows
# the bound method (see the output below) rather than the tensor's type.
v1.type

<bound method Variable.type of Variable containing:
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
       ...          ⋱          ...       
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
    0     0     0  ...      0     0     0
[torch.FloatTensor of size 64x872]
>

In [262]:
type(v1.data.numpy())

numpy.ndarray

In [301]:
x = v.data.numpy()[4]

In [302]:
seq_arr = np.argmax(x.reshape(-1, md.vocab_size), axis=1); seq_arr

array([20, 21, 13, 78,  6, 20, 21, 16])

In [None]:
idx2token

In [313]:
# Lazily map every index in seq_arr to its token; indices that are missing
# from the vocabulary become the empty string.
idx2token = md.concat_json['idx_to_token']
test = map(lambda token_idx: idx2token.get(str(token_idx), ''), seq_arr)
# test = [idx2token[f'{x}'] for x in seq_arr]; test

In [264]:
def convert_back(x, vocab_size):
    """Decode flattened one-hot (or score) vectors back into token indices.

    Parameters
    ----------
    x : numpy.ndarray or torch.Tensor
        Values whose total size is a multiple of ``vocab_size``.
    vocab_size : int
        Width of one one-hot slot.

    Returns
    -------
    numpy.ndarray
        1-D array with the argmax of every consecutive ``vocab_size``-wide slot.
    """
    if not isinstance(x, np.ndarray):
        # Tensors are detached and moved to host memory before conversion.
        x = x.detach().cpu().numpy()
    slots = x.reshape(-1, vocab_size)
    return np.argmax(slots, axis=1)

In [None]:
# Each row of v is a flattened (sequence_length x vocab_size) one-hot window,
# not a 28x28 image, so it is reshaped to a 1-channel image of that size.
show_adn_save("real",make_grid(v.view(-1,1,md.sequence_length,md.vocab_size).data))

In [None]:
# Each row of v1 is a flattened (sequence_length x vocab_size) sample, not a
# 28x28 image, so it is reshaped to a 1-channel image of that size.
show_adn_save("generate",make_grid(v1.view(-1,1,md.sequence_length,md.vocab_size).data))