In [24]:
from ola_cb import * 
from ola_RNN import * 

import os, time, copy, math, re, json, pickle, random
import numpy as np
import pandas as pd

import torch, torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.ticker as ticker

from functools import partial 

# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_available else "cpu")
print(f'''using device {device}''')

# Working directory of the notebook, used later as the root for data paths.
# os.getcwd() replaces the IPython-only `!pwd` shell escape so this cell also
# runs as plain Python and on platforms without a `pwd` command.
path = os.getcwd()
print(path)

using device cuda:0
/home/r2/Documents/RNNexp


In [25]:
def one_gru_batch(xb,yb,cb):
    """Run a single optimisation step on one batch.

    The forward pass and loss computation are delegated to
    ``cb.learn.model.batch_forward``; the hidden state it returns is stored
    back on ``cb.learn.hidden`` so the next batch can continue from it.

    Args:
        xb: input batch, passed through to ``batch_forward``.
        yb: target batch, passed through to ``batch_forward``.
        cb: callback handler exposing ``learn`` plus the ``after_loss``,
            ``after_backward`` and ``after_step`` hooks. A falsy return from
            any hook aborts the remainder of the step.

    Returns:
        None.
    """
    # Use the learner bound to the handler rather than a module-level
    # `learn`, so the function does not depend on notebook globals.
    learn = cb.learn
    pred, learn.hidden, loss = learn.model.batch_forward(xb, yb, learn.hidden, learn.loss_fn)
    if not cb.after_loss(loss): return
    loss.backward()
    if not cb.after_backward(): return
    learn.opt.step()
    if not cb.after_step(): return
    # Only reached when no hook vetoed the step.
    learn.opt.zero_grad()

def fit_gru(epoches, learn, cb=None, itters=math.inf):
    """Train ``learn.model`` for ``epoches`` passes over ``learn.data.train_dl``.

    Every stage is gated by a callback hook on ``cb``; a falsy return from
    ``begin_fit``, ``begin_epoch``, ``begin_batch``, ``begin_validate`` or
    ``after_epoch`` aborts training immediately. A truthy ``cb.do_stop()``
    ends the current epoch early.

    Args:
        epoches: number of epochs to run.
        learn: learner providing ``model``, ``data.train_dl`` and ``hidden``.
        cb: callback handler. Despite the ``None`` default a handler is
            required, because its hooks are called unconditionally.
        itters: accepted for interface compatibility; not used by this loop.

    Returns:
        None.
    """
    # Store the initial hidden state on the learner: one_gru_batch reads and
    # writes cb.learn.hidden, so keeping it in a local would leave the
    # learner's hidden state unset for the first batch.
    # NOTE(review): assumes cb.begin_fit(learn) binds this learner as cb.learn.
    learn.hidden = learn.model.initHidden(learn.data.train_dl.bs)
    if not cb.begin_fit(learn):           return
    for epoch in range(epoches):
        if not cb.begin_epoch(epoch):     return
        for xb, yb in learn.data.train_dl:
            if not cb.begin_batch(xb,yb): return
            one_gru_batch(xb,yb,cb)
            # Validation hook runs after every batch; the callbacks decide
            # whether any validation work actually happens.
            if not cb.begin_validate():   return
            if cb.do_stop():              break
        if not cb.after_epoch():          return
    cb.after_fit()

In [26]:
class Learner():
    """Container tying together a model, optimiser, loss function and data.

    Also carries mutable training state: the recurrent ``hidden`` state, the
    ``in_train`` flag, loss histories in ``stats`` and epoch/iteration
    counters.
    """

    def __init__(self, model, loss_fn, opt, data, lr):
        """Store the training components and reset the bookkeeping state.

        Args:
            model: the network to train.
            loss_fn: callable ``loss_fn(pred, target)``.
            opt: optimiser exposing ``param_groups``, ``step`` and ``zero_grad``.
            data: data bundle (stored as-is).
            lr: accepted for call-site compatibility; the stored learning rate
                is read from the optimiser's first param group instead.
        """
        self.model, self.opt, self.loss_fn, self.data = model, opt, loss_fn, data
        self._lr     = opt.param_groups[0]['lr']
        self.hidden  = None
        # one_batch reads this flag, so give it a defined value before any
        # hook has had a chance to set it.
        self.in_train = True
        self.stats   = Struct()
        self.stats.valid_loss = []
        self.stats.train_loss = []
        self.n_epochs = 0.
        self.n_iters  = 0

    @property
    def lr(self):
        """Learning rate last recorded on this learner."""
        return self._lr

    @lr.setter
    def lr(self,lr):
        # Keep the cached value and every optimiser param group in sync.
        self._lr = lr
        for param_group in self.opt.param_groups:
            param_group['lr'] = lr

    def one_batch(self, i, xb, yb):
        """Process one batch, firing a named event after each stage.

        Events are dispatched by calling the instance itself
        (``self('event_name')``).
        NOTE(review): this class defines no ``__call__``; a subclass or mixin
        must supply it before this method can be used.

        Args:
            i: index of the batch, stored as ``self.iter``.
            xb: input batch.
            yb: target batch.
        """
        try:
            self.iter = i
            self.xb,self.yb = xb,yb;                       self('begin_batch')
            self.pred = self.model(self.xb);               self('after_pred')
            self.loss = self.loss_fn(self.pred, self.yb);  self('after_loss')
            # Outside training only the forward pass and loss are needed.
            if not self.in_train: return
            self.loss.backward();                          self('after_backward')
            self.opt.step();                               self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:                       self('after_cancel_batch')
        finally:                                           self('after_batch')
            

In [27]:
class ParamScheduler(Callback):
    """Callback that sets one optimiser hyper-parameter from a schedule.

    Before each training batch the value ``sched_func(learn.n_epochs)`` is
    written to ``pname`` in every param group of ``learn.opt``.
    """
    _order=5

    def __init__(self, pname, sched_func):
        """
        Args:
            pname: key in each optimiser param group to overwrite (e.g. 'lr').
            sched_func: callable mapping ``learn.n_epochs`` to the new value.
        """
        self.pname,self.sched_func = pname,sched_func

    def begin_fit(self,learn):
        # NOTE(review): relies on the base class to bind `self.learn`.
        super().begin_fit(learn)
        return True

    def set_param(self):
        """Evaluate the schedule once and apply it to every param group."""
        value = self.sched_func(self.learn.n_epochs)
        for pg in self.learn.opt.param_groups:
            pg[self.pname] = value

    def begin_batch(self,xb,yb):
        """Update the scheduled parameter while training; never veto the batch."""
        if self.learn.in_train: self.set_param()
        # Return True like the other hooks in this file: the training loop
        # treats a falsy hook result as a request to stop.
        return True

In [28]:
class StatsCallback(Callback):
    """Callback that periodically records the validation loss.

    Whenever ``learn.n_iters`` is a multiple of ``valid_every`` the result of
    ``get_valid_rnn(learn, itters=valid_itters)`` is appended to
    ``learn.stats.valid_loss``.
    """
    _order = 10

    # How often (in iterations) to validate, and the `itters` value handed to
    # get_valid_rnn. Class attributes so they can be overridden per instance
    # or subclass without changing the constructor.
    valid_every  = 100
    valid_itters = 30

    def begin_fit(self,learn):
        # NOTE(review): relies on the base class to bind `self.learn`.
        super().begin_fit(learn)
        return True

    def begin_validate(self):
        """Run a validation pass when due; always let training continue."""
        learn = self.learn
        if learn.n_iters % self.valid_every == 0:
            # Remember the current mode so it can be restored afterwards;
            # otherwise every later batch would see in_train == False.
            was_training = getattr(learn, 'in_train', True)
            learn.in_train = False
            try:
                learn.stats.valid_loss.append(get_valid_rnn(learn, itters=self.valid_itters))
            finally:
                learn.in_train = was_training
        return True
    

In [29]:
class GRU(nn.Module):
    """Hand-written single-layer GRU cell with a log-softmax output head.

    The output head maps the concatenation of the current input and the new
    hidden state back to ``in_sz`` log-probabilities.
    """

    def __init__(self, in_sz, hd_sz):
        """
        Args:
            in_sz: size of each input vector (and of the output distribution).
            hd_sz: size of the hidden state.
        """
        super().__init__()
        self.in_sz = in_sz
        self.hd_sz = hd_sz

        # One fused projection each for hidden state and input; the result is
        # split into three hd_sz-wide chunks (update, reset, candidate).
        self.h_lin = nn.Linear(self.hd_sz,3*self.hd_sz)
        self.x_lin = nn.Linear(self.in_sz,3*self.hd_sz)

        self.up_sig = nn.Sigmoid()
        self.re_sig = nn.Sigmoid()

        self.o1      = nn.Linear(self.hd_sz+self.in_sz,self.in_sz)

        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self,input,hidden):
        """Advance the cell by one step.

        Args:
            input: 2-D tensor whose second dimension is ``in_sz``.
            hidden: 2-D tensor whose second dimension is ``hd_sz``.

        Returns:
            Tuple ``(prediction, h_new)``: log-probabilities over ``in_sz``
            classes and the updated hidden state.
        """
        x = self.x_lin(input)
        h = self.h_lin(hidden)
        x_u,x_r,x_n = x.chunk(3,1)
        h_u,h_r,h_n = h.chunk(3,1)
        update_gate = self.up_sig(x_u+h_u)
        reset_gate  = self.re_sig(x_r+h_r)
        # The reset gate scales only the hidden contribution to the candidate.
        new_gate    = torch.tanh(x_n + reset_gate * h_n)
        # Update gate near 1 keeps the old state, near 0 takes the candidate.
        h_new       = update_gate * hidden + (1 - update_gate) * new_gate

        combined   = torch.cat((input,h_new),1)
        combined   = self.o1(combined)

        prediction = self.softmax(combined)

        return prediction, h_new

    def batch_forward(self,xb,yb,hidden,loss_fn):
        """Unroll the cell over a batch of sequences and average the loss.

        Puts the module in training mode, then steps along dimension 1 of
        ``xb``. At each step ``unpad`` filters ``x``, ``y`` and ``hidden``;
        the loop stops at the first step where no rows remain.

        Args:
            xb: 3-D input tensor indexed as ``xb[row, step, feature]``.
            yb: targets indexed as ``yb[row, step]``.
            hidden: hidden state carried over from the previous batch, or
                ``None`` to start from zeros.
            loss_fn: callable ``loss_fn(output, y)`` returning the step loss.

        Returns:
            Tuple ``(output, hidden, loss)``: the last step's prediction
            (``None`` if no step ran), the detached hidden state, and the loss
            averaged over the steps that were actually processed.
        """
        self.train()
        # Start from a fresh state when none is supplied or when the batch
        # signals a restart.
        # NOTE(review): xb[0,0,1] == 1 presumably marks a start-of-sequence
        # token in the first row — confirm against the data loader.
        if hidden is None or xb[0,0,1].item() == 1:
            hidden = self.initHidden(xb.shape[0])
        loss, steps, output = 0, 0, None
        for char in range(xb.shape[1]):
            x,y           = xb[:,char],yb[:,char]
            # NOTE(review): unpad is a project helper; presumably it drops
            # padded rows from all three tensors — confirm.
            x,y,hidden    = unpad(x,y,hidden)
            if x.shape[0] == 0: break
            output,hidden = self.forward(x,hidden)
            loss  += loss_fn(output,y)
            steps += 1
        # Divide by the number of processed steps, not by the loop index:
        # after an early break `char + 1` counts a step that never ran.
        return output,hidden.detach(),loss/max(steps,1)

    def initHidden(self, bs):
        """Return an all-zero hidden state with ``bs`` rows, passed through ``cuda``."""
        return cuda(torch.zeros(bs,self.hd_sz))

In [30]:
# Hyper-parameters: batch size, sequence length handed to the data loaders,
# and the optimiser's learning rate.
bs, sql, lr = 20, 30, 0.0005

# Two-phase cosine schedule built from the project's scheduling helpers.
sched = combine_scheds(
    [0.3, 0.7],
    [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)],
)

# Load the tweet data and attach train / validation loaders to it.
data          = pp_trumpdata(f"{path}/data/trump/", [0.9,0.95], bs)
data.train_dl = TweetDataLoader(data, data.train.tweets, bs, sql, shuffle=True)
data.valid_dl = TweetDataLoader(data, data.valid.tweets, bs, sql, shuffle=False)

# Model sized to the decoder vocabulary, with a 150-unit hidden state.
model = cuda(GRU(len(data.decoder), 150))
opt   = optim.RMSprop(model.parameters(), lr=lr)

learn = Learner(model, nn.NLLLoss(), opt, data, lr=lr)
# ParamScheduler('lr', sched) is not attached to the handler below.
cbs   = CallbackHandler([CounterCallback(500), StatsCallback()])

1
10


In [31]:
# Train for a single epoch with the callbacks configured above.
fit_gru(epoches=1, learn=learn, cb=cbs)

0
True
1
True
2
True
3
True
4
True
5
True
6
True
7
True
8
True
9
True
10
True
11
True
12
True
13
True
14
True
15
True
16
True
17
True
18
True
19
True
20
True
21
True
22
True
23
True
24
True
25
True
26
True
27
True
28
True
29
True
30
True
31
True
32
True
33
True
34
True
35
True
36
True
37
True
38
True
39
True
40
True
41
True
42
True
43
True
44
True
45
True
46
True
47
True
48
True
49
True
50
True
51
True
52
True
53
True
54
True
55
True
56
True
57
True
58
True
59
True
60
True
61
True
62
True
63
True
64
True
65
True
66
True
67
True
68
True
69
True
70
True
71
True
72
True
73
True
74
True
75
True
76
True
77
True
78
True
79
True
80
True
81
True
82
True
83
True
84
True
85
True
86
True
87
True
88
True
89
True
90
True
91
True
92
True
93
True
94
True
95
True
96
True
97
True
98
True
99
True
getting validation
100
True
101
True
102
True
103
True
104
True
105
True
106
True
107
True
108
True
109
True
110
True
111
True
112
True
113
True
114
True
115
True
116
True
117
True
118
True
119
True
120
True
12

KeyboardInterrupt: 

In [None]:
# Validation loss against the index of the validation pass that produced it.
plt.figure()
plt.plot(
    list(range(len(learn.stats.valid_loss))),
    learn.stats.valid_loss,
    label='vloss',
)
plt.legend()

In [None]:
# Compare the available annealing schedules over the same start/end pair.
annealings = ["NO", "LINEAR", "COS", "EXP"]

a = torch.arange(100)               # x-axis positions
p = torch.linspace(0.01, 1, 100)    # values fed to each schedule

fns = [sched_no, sched_lin, sched_cos, sched_exp]
for fn, t in zip(fns, annealings):
    f = fn(2, 1e-2)
    plt.plot(a, list(map(f, p)), label=t)
plt.legend();

In [None]:
# Shape of the combined schedule over the same sample points.
plt.plot(a, list(map(sched, p)))