In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from transformers import AutoTokenizer
from datasets import load_dataset # huggingface datasets
from torchvision import datasets  # torchvision datasets
from torch.autograd import Variable
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import joblib
import seaborn as sns; sns.set(color_codes=True)
sns.set_style("white")

In [None]:
torch.__version__

In [2]:
# Make a cache for optimizers so we can quickly re-compare results without having to re-train
!mkdir -p _cache
cache = joblib.Memory(location='_cache', verbose=0)

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

# L2O Learned Optimizers

In [6]:
class Optimizer_Grad(nn.Module):
    def __init__(self, hidden_size=20):
        super().__init__()
        self.lstm = nn.LSTMCell(1, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        
    def forward(self, input, hidden, cell):
        # Preprocess Optionally?
        hidden0, cell0 = self.lstm(input, (hidden[0], cell[0]))
        hidden1, cell1 = self.lstm2(hidden0, (hidden[1], cell[1]))
        return self.output(hidden1), (hidden0, hidden1), (cell0, cell1)


In [4]:
# Optimizer with Preprocessed input features
class Optimizer_PP(nn.Module):
    def __init__(self, preproc=False, hidden_size=20, preproc_factor=10.0):
        super().__init__()
        self.hidden_size = hidden_size
        if preproc:
            self.lstm = nn.LSTMCell(2, hidden_size)
        else:
            self.lstm = nn.LSTMCell(1, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)
        self.preproc = preproc
        self.preproc_factor = preproc_factor
        self.preproc_threshold = np.exp(-preproc_factor)
        
    def forward(self, inp, hidden, cell):
        if self.preproc:
            # Implement preproc described in Appendix A
            
            # Note: we do all this work on tensors, which means
            # the gradients won't propagate through inp. This
            # should be ok because the algorithm involves
            # making sure that inp is already detached.
            inp = inp.data
            inp2 = torch.zeros(inp.size()[0], 2).to(device)
            keep_grads = (torch.abs(inp) >= self.preproc_threshold).squeeze()
            inp2[:, 0][keep_grads] = (torch.log(torch.abs(inp[keep_grads]) + 1e-8) / self.preproc_factor).squeeze()
            inp2[:, 1][keep_grads] = torch.sign(inp[keep_grads]).squeeze()
            
            inp2[:, 0][~keep_grads] = -1
            inp2[:, 1][~keep_grads] = (float(np.exp(self.preproc_factor)) * inp[~keep_grads]).squeeze()
            inp = Variable(inp2).to(device)
        hidden0, cell0 = self.lstm(inp, (hidden[0], cell[0]))
        hidden1, cell1 = self.lstm2(hidden0, (hidden[1], cell[1]))
        return self.output(hidden1), (hidden0, hidden1), (cell0, cell1)

In [5]:
def detach_var(v):
    var = Variable(v.data, requires_grad=True).to(device)
    var.retain_grad()
    return var

In [6]:
@cache.cache
def fit_optimizer(target_cls, target_to_opt, preproc=False, unroll=20, optim_it=100, n_epochs=20, n_tests=100, lr=0.001, out_mul=1.0):
    # L2O Optimizer Network
    opt_net = Optimizer_PP(preproc=preproc).to(device)
    # Meta-Optimizer which will optimize the L2O Optimizer Network
    meta_opt = torch.optim.Adam(opt_net.parameters(), lr=lr)
    
    best_net = None
    best_loss = 100000000000000000
    
    for _ in tqdm(range(n_epochs), 'epochs'):
        for _ in tqdm(range(20), 'iterations'):
            do_fit(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True)
            torch.cuda.empty_cache() # CHECK THIS
        
        loss = (np.mean([
            np.sum(do_fit(opt_net, meta_opt, target_cls, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=False))
            for _ in tqdm(range(n_tests), 'tests')
        ]))
        print(loss)
        if loss < best_loss:
            print(best_loss, loss)
            best_loss = loss
            best_net = copy.deepcopy(opt_net.state_dict())
            
    return best_loss, best_net

In [7]:
import gc

In [9]:
def do_fit(opt_net, meta_opt, target_class, target_to_opt, unroll, optim_it, n_epochs, out_mul, should_train=True, return_optimizee=False):
    print("do_fit")
    if should_train:
        opt_net.train()
    else:
        opt_net.eval()
        unroll = 1

    target = target_class(training=should_train)
    optimizee = target_to_opt().to(device)
    n_params = 0
    for p in optimizee.parameters():
        n_params += int(np.prod(p.size()))
    hidden_states = [Variable(torch.zeros(n_params, opt_net.hidden_size)).to(device) for _ in range(2)]
    cell_states = [Variable(torch.zeros(n_params, opt_net.hidden_size)).to(device) for _ in range(2)]
    all_losses_ever = []
    if should_train:
        meta_opt.zero_grad()
    all_losses = None
    for iteration in range(1, optim_it + 1):
        loss = optimizee(target)
                    
        if all_losses is None:
            all_losses = loss
        else:
            all_losses += loss
        
        all_losses_ever.append(loss.data.cpu().numpy())
        loss.backward(retain_graph=should_train)

        offset = 0
        result_params = {}
        hidden_states2 = [Variable(torch.zeros(n_params, opt_net.hidden_size)).to(device) for _ in range(2)]
        cell_states2 = [Variable(torch.zeros(n_params, opt_net.hidden_size)).to(device) for _ in range(2)]
        hidden_states2 = []
        cell_states2 = []
        for name, p in optimizee.all_named_parameters():
            cur_sz = int(np.prod(p.size()))
            # We do this so the gradients are disconnected from the graph but we still get
            # gradients from the rest
            gradients = detach_var(p.grad.view(cur_sz, 1))

            updates, new_hidden, new_cell = opt_net(
                gradients,
                [h[offset:offset+cur_sz] for h in hidden_states],
                [c[offset:offset+cur_sz] for c in cell_states]
            )
            for i in range(len(new_hidden)):
                hidden_states2[i][offset:offset+cur_sz] = new_hidden[i]
                cell_states2[i][offset:offset+cur_sz] = new_cell[i]
            # Update the specific hidden_states and cell_states
            # for i in range(len(new_hidden)):
            #     # indices need to be offset into the actual "hidden_states" and "cell_states"
            #     # offset_indices = indices + offset
            #     # hidden_states[i][offset_indices] = new_hidden[i]
            #     # cell_states[i][offset_indices] = new_cell[i]
            # hidden_states2.append((indices, new_hidden))
            # cell_states2.append((indices, new_cell))
            
            # Apply updates. For non-indexed params, just use the original value
            # all_updates = torch.zeros(cur_sz, 1).to(device)
            # all_updates[indices] = updates

            result_params[name] = p + updates.view(*p.size()) * out_mul
            result_params[name].retain_grad()
            
            offset += cur_sz
            

        # offset = 0
        # # zip the all_named_parameters with hidden_states2 and cell_states2
        # for (name, p), (indices, new_hidden), (_, new_cells) in zip(optimizee.all_named_parameters(), hidden_states2, cell_states2):
        #     cur_sz = int(np.prod(p.size()))
        #     # Update the specific hidden_states and cell_states
        #     for i in range(len(new_hidden)):
        #         # indices need to be offset into the actual "hidden_states" and "cell_states"
        #         offset_indices = indices + offset
        #         hidden_states[i][offset_indices] = new_hidden[i]
        #         cell_states[i][offset_indices] = new_cells[i]

        #     offset += cur_sz

                # Let's do some manual memory management here
        # del optimizee
        # gc.collect()
        # torch.cuda.empty_cache() # CHECK THIS
            

        if iteration % unroll == 0:
            if should_train:
                meta_opt.zero_grad()
                all_losses.backward()
                meta_opt.step()
                
            all_losses = None
                        
            optimizee = target_to_opt(**{k: detach_var(v) for k, v in result_params.items()}).to(device)
            hidden_states = [detach_var(v) for v in hidden_states]
            cell_states = [detach_var(v) for v in cell_states]
            
        else:
            optimizee = target_to_opt(**result_params).to(device)
            assert len(list(optimizee.all_named_parameters()))
            # hidden_states = hidden_states2
            # cell_states = cell_states2
        
        gc.collect()
        torch.cuda.empty_cache()
        
        

    if (return_optimizee):
        return all_losses_ever, optimizee
    return all_losses_ever

In [24]:
ind = np.random.choice(1000, 10, replace=False)
print(ind)
print(ind + 1000)

[468 697 108 374  70 904 906 457 584 109]
[1468 1697 1108 1374 1070 1904 1906 1457 1584 1109]


In [None]:
class QuadraticTarget:
    def __init__(self, **kwargs):
        self.W = Variable(torch.randn(10, 10)).to(device)
        self.y = Variable(torch.randn(10)).to(device)
        
    def get_loss(self, theta):
        return torch.sum((self.W.matmul(theta) - self.y)**2)
    
class QuadOptimizee(nn.Module):
    def __init__(self, theta=None):
        super().__init__()
        # Note: assuming the same optimization for theta as for
        # the function to find out itself.
        if theta is None:
            self.theta = nn.Parameter(torch.zeros(10))
        else:
            self.theta = theta
        
    def forward(self, target):
        return target.get_loss(self.theta)
    
    def all_named_parameters(self):
        return [('theta', self.theta)]


In [None]:
# for lr in tqdm([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001], 'all'):
for lr in tqdm([0.003]):
    print('Trying lr:', lr)
    print(fit_optimizer(QuadraticTarget, QuadOptimizee, lr=lr)[0])
    
loss, quad_optimizer = fit_optimizer(QuadraticTarget, QuadOptimizee, lr=0.003, n_epochs=100)
print(loss)

In [9]:
@cache.cache
def fit_normal(target_cls, target_to_opt, opt_class, n_tests=100, n_epochs=100, **kwargs):
    results = []
    for i in tqdm(range(n_tests), 'tests'):
        target = target_cls(training=False)
        optimizee = target_to_opt().to(device)
        optimizer = opt_class(optimizee.parameters(), **kwargs)
        total_loss = []
        for _ in range(n_epochs):
            loss = optimizee(target)
            
            total_loss.append(loss.data.cpu().numpy())
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        results.append(total_loss)
    return results

def find_best_lr_normal(target_cls, target_to_opt, opt_class, **extra_kwargs):
    best_loss = 1000000000000000.0
    best_lr = 0.0
    for lr in tqdm([1.0, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001], 'Learning rates'):
        try:
            loss = best_loss + 1.0
            loss = np.mean([np.sum(s) for s in fit_normal(target_cls, target_to_opt, opt_class, lr=lr, **extra_kwargs)])
        except RuntimeError:
            pass
        if loss < best_loss:
            best_loss = loss
            best_lr = lr
    return best_loss, best_lr

In [None]:
NORMAL_OPTS = [(torch.optim.Adam, {}), (torch.optim.RMSprop, {}), (torch.optim.SGD, {}), (torch.optim.SGD, {'nesterov': True, 'momentum': 0.9})]
OPT_NAMES = ['ADAM', 'RMSprop', 'SGD', 'NAG']
QUAD_LRS = [0.1, 0.03, 0.01, 0.01]
fit_data = np.zeros((100, 100, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)
    fit_data[:, :, i] = np.array(fit_normal(QuadraticLoss, QuadOptimizee, opt, lr=lr, **extra_kwargs))

opt = Optimizer_Grad().to(device)
opt.load_state_dict(quad_optimizer)
np.random.seed(0)
fit_data[:, :, len(OPT_NAMES)] = np.array([do_fit(opt, None, QuadraticLoss, QuadOptimizee, 1, 100, 100, out_mul=1.0, should_train=False) for _ in range(100)])


In [None]:
fit_data_est = np.mean(fit_data, axis=0)
plt.plot(fit_data_est[:, :], label=OPT_NAMES + ['LSTM'])
plt.yscale('log')
plt.xlabel('steps')
plt.ylabel('loss')
plt.title('Quadratic functions')
plt.legend()
plt.show()

# MLP on MNIST

In [None]:
class MNISTTarget:
    def __init__(self, training=True):
        dataset = datasets.MNIST('mnist_data', train=training, download=True,
                                 transform=torchvision.transforms.ToTensor()) 
        
        indices = list(range(len(dataset)))
        np.random.RandomState(10).shuffle(indices)
        if training:
            indices = indices[:len(indices) // 2] # Use only half of the data
        else:
            indices = indices[len(indices) // 2:]
        
        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=128,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
    
        self.batches = []
        self.cur_batch = 0

    def get_next_batch(self):
        if self.cur_batch >= len(self.batches):
            self.batches = []
            self.cur_batch = 0
            for batch_idx, (data, target) in enumerate(self.loader):
                self.batches.append((data.to(device), target.to(device)))
        result = self.batches[self.cur_batch]
        self.cur_batch += 1
        return result

class MNISTNet(nn.Module):
    def __init__(self, layer_sizes=[20], **kwargs):
        super().__init__()
        self.layer_sizes = [20]
        if (kwargs != {}):
            self.params = kwargs
        else:
            self.params = {}
            input_size = 28 * 28
            for i, layer_size in enumerate(layer_sizes):
                self.params['W' + str(i)] = nn.Parameter(torch.randn(input_size, layer_size) / np.sqrt(input_size))
                self.params['b' + str(i)] = nn.Parameter(torch.zeros(layer_size))
                input_size = layer_size

            self.params['W' + str(len(layer_sizes))] = nn.Parameter(torch.randn(input_size, 10) / np.sqrt(input_size))
            self.params['b' + str(len(layer_sizes))] = nn.Parameter(torch.zeros(10))

            self.mods = nn.ParameterList()
            for v in self.params.values():
                self.mods.append(v)
            
        self.activation = nn.Sigmoid()
        self.loss = nn.NLLLoss()

    def all_named_parameters(self):
        return [(k, v) for k, v in self.params.items()]

    def forward(self, target: MNISTTarget):
        x, y = target.get_next_batch()
        x = Variable(x.view(x.size()[0], 28*28)).to(device)
        y = Variable(y).to(device)
        # Ensure that the tensors both require_grad
        # Tensors are already moved to device in loader
        # x.requires_grad_()

        for i in range(len(self.layer_sizes)):
            x = self.activation(x @ self.params['W' + str(i)] + self.params['b' + str(i)])
        
        x = F.log_softmax(x @ self.params['W' + str(len(self.layer_sizes))] + self.params['b' + str(len(self.layer_sizes))], dim=1)

        loss = self.loss(x, y)
        return loss
        

In [None]:
QUAD_LRS = [0.03, 0.01, 1.0, 1.0]
N_TESTS = 20

NORMAL_OPTS = [(torch.optim.Adam, {})]
QUAD_LRS = [0.03]

fit_data = np.zeros((N_TESTS, 200, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)
    fit_data[:, :, i] = np.array(fit_normal(MNISTTarget, MNISTNet, opt, lr=lr, n_tests=N_TESTS, n_epochs=200, **extra_kwargs))


In [None]:
loss, mnist_optimizer = fit_optimizer(MNISTTarget, MNISTNet, lr=0.01, n_epochs=20, n_tests=20, preproc=True, out_mul=0.1)
print(loss)

In [None]:
NORMAL_OPTS = [(torch.optim.Adam, {}), (torch.optim.RMSprop, {}), (torch.optim.SGD, {}), (torch.optim.SGD, {'nesterov': True, 'momentum': 0.9})]
OPT_NAMES = ['ADAM', 'RMSprop', 'SGD', 'NAG']
QUAD_LRS = [0.03, 0.01, 1.0, 1.0]
fit_data = np.zeros((100, 100, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)
    fit_data[:, :, i] = np.array(fit_normal(MNISTTarget, MNISTNet, opt, lr=lr, **extra_kwargs))

opt = Optimizer_PP(preproc=True).to(device)
opt.load_state_dict(mnist_optimizer)
np.random.seed(0)
fit_data[:, :, len(OPT_NAMES)] = np.array([do_fit(opt, None, MNISTTarget, MNISTNet, 1, 100, 100, out_mul=0.1, should_train=False) for _ in range(100)])


In [None]:
fit_data_mean = np.mean(fit_data, axis=0)
plt.plot(fit_data_mean[:, :], label=OPT_NAMES + ['LSTM'])
plt.legend()
plt.show()

# ConvNet on MNIST

In [None]:
class MNISTTarget:
    def __init__(self, training=True):
        dataset = datasets.MNIST('mnist_data', train=training, download=True,
                                 transform=torchvision.transforms.ToTensor()) 
        
        indices = list(range(len(dataset)))
        np.random.RandomState(10).shuffle(indices)
        if training:
            indices = indices[:len(indices) // 2] # Use only half of the data
        else:
            indices = indices[len(indices) // 2:]
        
        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=128,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
    
        self.batches = []
        self.cur_batch = 0

    def get_next_batch(self):
        if self.cur_batch >= len(self.batches):
            self.batches = []
            self.cur_batch = 0
            for batch_idx, (data, target) in enumerate(self.loader):
                self.batches.append((data.to(device), target.to(device)))
        result = self.batches[self.cur_batch]
        self.cur_batch += 1
        return result

class MNISTResNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        if (kwargs != {}):
            self.params = kwargs
        else:
            self.params = {}
            input_size = 28 * 28

            # input image size: 28x28
            # after first conv, 3x3, padding 1 : 28x28
            # after second conv, 3x3, padding 1: 28x28
            # after maxpool: 14x14
            # after batch norm: 14x14
            # after fully connected: 10
            
            # Create the parameters for the first 3x3 convolution layer, padding 1
            self.params['W1'] = nn.Parameter(torch.randn(3, 1, 3, 3) / np.sqrt(3 * 1 * 3 * 3))
            self.params['b1'] = nn.Parameter(torch.zeros(3))

            # Create the parameters for the second 3x3 convolution layer, padding 1
            self.params['W2'] = nn.Parameter(torch.randn(3, 3, 3, 3) / np.sqrt(3 * 3 * 3 * 3))
            self.params['b2'] = nn.Parameter(torch.zeros(3))

            # Create the parameters for the final fully connected output layer, after a 2x2 max pooling
            self.params['W3'] = nn.Parameter(torch.randn(10, 3 * 14 * 14) / np.sqrt(10 * 3 * 14 * 14))
            self.params['b3'] = nn.Parameter(torch.zeros(10))

            self.mods = nn.ParameterList()
            for v in self.params.values():
                self.mods.append(v)
            
        self.activation = nn.ReLU()
        self.norm = nn.BatchNorm2d(3)
        self.loss = nn.NLLLoss()

    def all_named_parameters(self):
        return [(k, v) for k, v in self.params.items()]

    def forward(self, target: MNISTTarget):
        x, y = target.get_next_batch()
        x = Variable(x.view(x.size()[0], 28*28)).to(device)
        y = Variable(y).to(device)

        residual = x

        # Apply the first convolution, 3x3 with padding 1 and 3 channels
        x = F.conv2d(x.view(x.size()[0], 1, 28, 28), self.params['W1'], bias=self.params['b1'], padding=1)
        x = self.activation(x)

        # shape: Bx3x28x28

        # Apply the second convolution, 3x3 with padding 1
        x = F.conv2d(x, self.params['W2'], bias=self.params['b2'], padding=1)
        # shape: Bx3x28x28

        # Apply residual connection, x shape is Bx3x28x28, residual shape is Bx784
        # Apply residual per filter channel, need to duplicate residual 3 times
        residual = residual.view(residual.size()[0], 1, 28, 28)
        residual = residual.repeat(1, 3, 1, 1)
        x = x + residual
        

        # Apply activation
        x = self.activation(x)

        # Apply max pooling
        x = F.max_pool2d(x, 2)
        
        # Apply batch normalization
        x = self.norm(x)

        # Apply the final fully connected layer
        x = x.view(x.size()[0], -1)
        x = F.linear(x, self.params['W3'], bias=self.params['b3'])
        x = F.log_softmax(x, dim=1)

        loss = self.loss(x, y)
        return loss
        

In [None]:
loss, mnist_res_optimizer = fit_optimizer(MNISTTarget, MNISTResNet, lr=0.01, n_epochs=20, n_tests=20, preproc=True, out_mul=0.1)
print(loss)

In [None]:
NORMAL_OPTS = [(torch.optim.Adam, {})]
OPT_NAMES = ['ADAM']
QUAD_LRS = [0.02]
fit_data = np.zeros((100, 100, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)
    fit_data[:, :, i] = np.array(fit_normal(MNISTTarget, MNISTResNet, opt, lr=lr, **extra_kwargs))

opt = Optimizer_PP(preproc=True).to(device)
opt.load_state_dict(mnist_res_optimizer)
np.random.seed(0)
fit_data[:, :, len(OPT_NAMES)] = np.array([do_fit(opt, None, MNISTTarget, MNISTResNet, 1, 100, 100, out_mul=0.1, should_train=False) for _ in range(100)])


In [None]:
fit_data_mean = np.mean(fit_data, axis=0)
plt.plot(fit_data_mean[:, :], label=OPT_NAMES + ['LSTM'])
plt.legend()
plt.show()

# Transfering ConvNet on MNIST to FashionMNIST

In [None]:
class FashionMNISTTarget:
    def __init__(self, training=True):
        dataset = datasets.FashionMNIST('fashion_mnist_data', train=training, download=True,
                                 transform=torchvision.transforms.ToTensor()) 
        
        indices = list(range(len(dataset)))
        np.random.RandomState(10).shuffle(indices)
        if training:
            indices = indices[:len(indices) // 2] # Use only half of the data
        else:
            indices = indices[len(indices) // 2:]
        
        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=128,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
    
        self.batches = []
        self.cur_batch = 0

    def get_next_batch(self):
        if self.cur_batch >= len(self.batches):
            self.batches = []
            self.cur_batch = 0
            for batch_idx, (data, target) in enumerate(self.loader):
                self.batches.append((data.to(device), target.to(device)))
        result = self.batches[self.cur_batch]
        self.cur_batch += 1
        return result

In [None]:
opt = Optimizer_PP(preproc=True).to(device)
opt.load_state_dict(mnist_res_optimizer)
losses, optimizee = do_fit(opt, None, FashionMNISTTarget, MNISTResNet, 1, 200, 200, out_mul=0.1, should_train=False, return_optimizee=True)

In [None]:
plt.plot(losses)
plt.show()

In [None]:
NORMAL_OPTS = [(torch.optim.Adam, {})]
OPT_NAMES = ['ADAM']
QUAD_LRS = [0.02]
fit_data = np.zeros((100, 100, len(OPT_NAMES) + 1))
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)
    fit_data[:, :, i] = np.array(fit_normal(FashionMNISTTarget, MNISTResNet, opt, lr=lr, **extra_kwargs))
    print(fit_data[:, -1, i].mean(), fit_data[:, -1, i].std())

opt = Optimizer_PP(preproc=True).to(device)
opt.load_state_dict(mnist_res_optimizer)
np.random.seed(0)
fit_data[:, :, len(OPT_NAMES)] = np.array([do_fit(opt, None, FashionMNISTTarget, MNISTResNet, 1, 100, 100, out_mul=0.1, should_train=False) for _ in range(100)])


In [None]:
fit_data_mean = np.mean(fit_data, axis=0)
plt.plot(fit_data_mean[:, :], label=OPT_NAMES + ['LSTM'])
plt.legend()
plt.show()

for i in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    print(i)
    print(fit_data_mean[:, :])


# BERT

In [9]:
# Creates batches of (x, y) where x is of (Seq, VocabSize) [ one of vocab_size is masked ] and y is of (1) [ vocab index of masked element ]
class MLMTarget:
    def __init__(self, training=True):
        dataset = load_dataset("openwebtext", split=("train")).with_format("torch", device=device)

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        
        indices = list(range(len(dataset)))
        np.random.RandomState(10).shuffle(indices)

        self.loader = torch.utils.data.DataLoader(
            dataset, batch_size=2,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
    
        self.batches = []
        self.cur_batch = 0

        self.iter = iter(self.loader)

    def get_next_batch(self):
        # if self.cur_batch >= len(self.batches):
        #     self.batches = []
        #     self.cur_batch = 0
        #     for batch_idx, data in enumerate(self.loader):
        #         self.batches.append(data)
                
        # result = self.batches[self.cur_batch]
        # self.cur_batch += 1

        result = next(self.iter)

        # tokenize the data, ensuring it's on device
        result = self.tokenizer(result["text"], padding=True, truncation=True, return_tensors="pt").to(device)

        # get the random indices for each sequence to mask
        mask_indices = torch.randint(0, result["input_ids"].size()[1], result["input_ids"].size()[:1] + (1,)).to(device)

        # save the targets
        targets = result["input_ids"][torch.arange(result["input_ids"].size()[0]), mask_indices.squeeze()].to(device)

        # mask the data
        result["input_ids"][torch.arange(result["input_ids"].size()[0]), mask_indices.squeeze()] = self.tokenizer.mask_token_id

        # save the batch


        return (result, targets)

In [18]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.mask_token_id

103

In [5]:
mlm = MLMTarget()

Found cached dataset openwebtext (/home/dadur604604/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521)


In [15]:
d,t = mlm.get_next_batch()

A
B
C
D
E


In [24]:
d

{'input_ids': tensor([[  101,  2043, 18431,  ...,  1000,  2056,   102],
        [  101, 19962, 22599,  ...,     0,     0,     0],
        [  101,  2007,  1037,  ...,     0,     0,     0],
        ...,
        [  101,  2131,  1996,  ...,  3357,  2000,   102],
        [  101,  3602,  1024,  ...,  1010,  2293,   102],
        [  101,  3745,  3025,  ...,  2003,  1996,   102]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')}

In [10]:
# Tiny BERT model with 2 transformer layers, and a hidden size of 128.
# Input and output are for masked language modeling.
# Expects input of shape (B, Seq, Vocab_Size), where one of the sequence elements is masked.
# Output is shape (B, Vocab_Size), where the masked element is predicted.
# Labels are expected of shape (B,) where each element is the vocab-size index of the masked element.
class BERTMLMNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.hidden_size = 128
        self.intermediate_size = 128
        self.vocab_size  = 30522
        self.num_attention_heads = 2
        self.num_layers = 1

        if (kwargs != {}):
            self.params = kwargs
        else:
            self.params = {}
            self.layer_norms = nn.ModuleList()
            # Embedding weights size: (vocab_size, hidden_size)
            self.params["W_e"] = nn.Parameter(torch.randn(self.vocab_size, self.hidden_size))
            self.params["b_e"] = nn.Parameter(torch.randn(self.hidden_size))

            # Define the Bert layers
            # Each layer contains: (1) self-attention, (2) dense layer (hidden -> intermediate), (3) dense layer (intermediate -> hidden)
            for i in range(self.num_layers):
                # Self attention weights size: (3*hidden_size, hidden_size)
                self.params["W_a" + str(i)] = nn.Parameter(torch.randn(3*self.hidden_size, self.hidden_size))
                # Self attention bias size: (hidden_size)
                self.params["b_a" + str(i)] = nn.Parameter(torch.randn(3*self.hidden_size))
                self.params["W_ao" + str(i)] = nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
                self.params["b_ao" + str(i)] = nn.Parameter(torch.randn(self.hidden_size))

                # dense intermediate
                # self.params["W_i" + str(i)] = nn.Parameter(torch.randn(self.hidden_size, self.intermediate_size))
                # self.params["b_i" + str(i)] = nn.Parameter(torch.randn(self.intermediate_size))

                # dense output
                # self.params["W_o" + str(i)] = nn.Parameter(torch.randn(self.intermediate_size, self.hidden_size))
                # self.params["b_o" + str(i)] = nn.Parameter(torch.randn(self.hidden_size))
                # self.layer_norms.append(nn.LayerNorm(self.hidden_size))

            # Define the weights for masked language prediction head
            # After 2nd attention layer (size: hidden_size), follow by a dense layer, then activation, then fc (from hidden_size to vocab_size)
            self.params["W_dense"] = nn.Parameter(torch.randn(self.hidden_size, self.hidden_size))
            self.params["b_dense"] = nn.Parameter(torch.randn(self.hidden_size))

            self.params["W_output"] = nn.Parameter(torch.randn(self.hidden_size, self.vocab_size))
            self.params["b_output"] = nn.Parameter(torch.randn(self.vocab_size))

            

            self.mods = nn.ParameterList()
            for v in self.params.values():
                self.mods.append(v)
            
        self.layer_norm_pred = nn.LayerNorm(self.hidden_size)
        self.activation = nn.GELU()
        self.loss = nn.CrossEntropyLoss()

    def all_named_parameters(self):
        return [(k, v) for k, v in self.params.items()]

    def forward(self, target: MLMTarget):
        x, y = target.get_next_batch()
        # Convert attention mask to bool mask
        attention_mask = x["attention_mask"].bool()
        x = x["input_ids"]
        y = y.long()

        # x is shape (B, Seq)
        # x is tensor of vocab indices, convert to one-hot
        x = F.one_hot(x, self.vocab_size)
        x = x.float()        
        
        # x is shape (B, Seq, Vocab_Size)
        # y is shape (B, Vocab_Size)


        # Embedding layer
        x = x @ self.params["W_e"] + self.params["b_e"]
        # x is shape (B, Seq, Hidden_Size)


        residual = x

        # Transformer Layers
        for i in range(self.num_layers):
            # Self attention
            # x should be seq first, batch second
            x = x.permute(1, 0, 2)
            x = F.multi_head_attention_forward(
                x, x, x, self.hidden_size, self.num_attention_heads, self.params["W_a" + str(i)], self.params["b_a" + str(i)], None, None, False, 0.0, self.params["W_ao" + str(i)], self.params["b_ao" + str(i)])[0]
            x = x.permute(1, 0, 2)

            # Dense intermediate
            # x = x @ self.params["W_i" + str(i)] + self.params["b_i" + str(i)]
            x = self.activation(x)

            # # Output (dense -> residual -> layernorm)
            # x = x @ self.params["W_o" + str(i)] + self.params["b_o" + str(i)]
            # x = x + residual
            # x = self.layer_norms[i](x)

            # residual = x

        
        # Masked language prediction head (dense -> activation -> layernorm -> fc decoder)
        # x is shape (B, Seq, Hidden_Size)
        x = x @ self.params["W_dense"] + self.params["b_dense"]
        x = self.activation(x)
        x = self.layer_norm_pred(x)
        x = x @ self.params["W_output"] + self.params["b_output"]

        # x is shape (B, Seq, Vocab_Size) [ vocab_size is logits ]

        loss = self.loss(x[:, 0, :], y)
        print(loss)
        return loss
        

In [None]:
sum([p.nelement()*20 for p in BERTMLMNet().parameters()])*2*2*2

In [None]:
bmlmnet = BERTMLMNet()
bmlmnet.to(device)

In [None]:
bmlmnet.forward(mlm)

In [31]:
del bmlmnet

In [28]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [20]:
import objgraph

In [None]:
import gc
i = 0
for obj in gc.get_objects():
    if i > 100:
        break

    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            i += 1
            print(type(obj), obj.size())
            # If object type is a torch Tensor
            if type(obj) == torch.nn.Parameter:
                objgraph.show_refs(obj, filename='ref_topo.png')
                break

            referrers = gc.get_referrers(obj)
            for ref in referrers:
                print('\t', type(ref))
    except:
        pass

In [27]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   5354 MiB |   6411 MiB | 662364 MiB | 657010 MiB |
|       from large pool |   5312 MiB |   6323 MiB | 661823 MiB | 656510 MiB |
|       from small pool |     41 MiB |    102 MiB |    541 MiB |    499 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   5354 MiB |   6411 MiB | 662364 MiB | 657010 MiB |
|       from large pool |   5312 MiB |   6323 MiB | 661823 MiB | 656510 MiB |
|       from small pool |     41 MiB |    102 MiB |    541 MiB |    499 MiB |
|---------------------------------------------------------------

In [29]:
loss, mlm_optimizer = fit_optimizer(MLMTarget, BERTMLMNet, lr=0.01, n_epochs=10, unroll=5, n_tests=5, preproc=True, out_mul=0.1)
print(loss)

epochs:   0%|          | 0/10 [00:00<?, ?it/s]

iterations:   0%|          | 0/20 [00:00<?, ?it/s]

do_fit


Found cached dataset openwebtext (/home/dadur604604/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521)


: 

: 

In [28]:
NORMAL_OPTS = [(torch.optim.Adam, {})]
OPT_NAMES = ['ADAM']
QUAD_LRS = [0.02]
fit_data = []
for i, ((opt, extra_kwargs), lr) in enumerate(zip(NORMAL_OPTS, QUAD_LRS)):
    np.random.seed(0)

    losses = np.array(fit_normal(MLMTarget, BERTMLMNet, opt, n_tests=1, n_epochs=20, lr=lr, **extra_kwargs))
    fit_data.append(losses)

In [29]:
losses

array([[56.866653, 46.316788, 43.42909 , 32.41855 , 25.812342, 31.459908,
        31.283218, 28.418797, 32.26407 , 30.252611, 27.90635 , 20.043835,
        23.69626 , 16.425978, 27.682356, 15.457117, 13.917219, 15.308836,
        23.49728 , 18.797798]], dtype=float32)