<a href="https://colab.research.google.com/github/nicolekwli/final-year-project/blob/main/notebooks/fyp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import os

The class for the full model, which supports three variants:

1. The CPC model according to Oord et al.
2. GIM model, where the last encoding layer is trained together with the autoregressor
3. GIM model in which the last autoregressive layer is trained independently

In [None]:
class FullModel(nn.Module):
    """Chain of ``IndependentModule`` blocks trained jointly or greedily.

    The layout is selected by ``opt.model_splits``:

    * ``1`` -- CPC model: a single module that receives every encoder layer.
    * ``5`` -- GIM model: one module per encoder layer; the last module is
      always built with ``use_autoregressive=True`` so the last encoding
      layer is trained together with the autoregressor.
    * ``6`` -- GIM model: one module per encoder layer; when
      ``opt.use_autoregressive`` is false, an extra encoder-less
      autoregressive module is appended and trained independently.
    """

    def __init__(
        self,
        opt,
        kernel_sizes,
        strides,
        padding,
        enc_hidden,
        reg_hidden,
        calc_accuracy=False,
    ):
        """Build ``self.fullmodel`` for the requested split.

        Args:
            opt: Options object. ``model_splits`` is read here, and
                ``use_autoregressive`` is read for splits 5 and 6.
            kernel_sizes: Encoder kernel size of each layer.
            strides: Encoder stride of each layer.
            padding: Encoder padding of each layer.
            enc_hidden: Forwarded as ``enc_hidden``; also used as the
                ``enc_input`` of every per-layer module after the first.
            reg_hidden: Forwarded as ``reg_hidden``.
            calc_accuracy: Forwarded to every module.

        Raises:
            ValueError: If ``opt.model_splits`` is not 1, 5 or 6.
        """
        super().__init__()

        self.opt = opt
        self.reg_hidden = reg_hidden
        self.enc_hidden = enc_hidden

        # Modules are stored in training order; forward() chains them.
        self.fullmodel = nn.ModuleList([])

        if self.opt.model_splits == 1:
            # CPC model: one module holding the whole encoder.
            self.fullmodel.append(
                independent_module.IndependentModule(
                    opt,
                    enc_kernel_sizes=kernel_sizes,
                    enc_strides=strides,
                    enc_padding=padding,
                    enc_hidden=enc_hidden,
                    reg_hidden=reg_hidden,
                    calc_accuracy=calc_accuracy,
                )
            )
        elif self.opt.model_splits == 5:
            # GIM model, where the last encoding layer is trained together
            # with the autoregressor.
            enc_input = 1
            last_idx = len(kernel_sizes) - 1

            for i in range(last_idx):
                self.fullmodel.append(
                    self._single_layer_module(
                        i,
                        enc_input,
                        kernel_sizes,
                        strides,
                        padding,
                        self.opt.use_autoregressive,
                        calc_accuracy,
                    )
                )
                enc_input = enc_hidden

            self.fullmodel.append(
                self._single_layer_module(
                    last_idx,
                    enc_input,
                    kernel_sizes,
                    strides,
                    padding,
                    True,
                    calc_accuracy,
                )
            )
        elif self.opt.model_splits == 6:
            # GIM model in which the last autoregressive layer is trained
            # independently.
            enc_input = 1

            for i in range(len(kernel_sizes)):
                self.fullmodel.append(
                    self._single_layer_module(
                        i,
                        enc_input,
                        kernel_sizes,
                        strides,
                        padding,
                        self.opt.use_autoregressive,
                        calc_accuracy,
                    )
                )
                enc_input = enc_hidden

            if not self.opt.use_autoregressive:
                # Append a separate, encoder-less autoregressive module.
                self.fullmodel.append(
                    independent_module.IndependentModule(
                        opt,
                        enc_input=enc_input,
                        enc_hidden=enc_hidden,
                        reg_hidden=reg_hidden,
                        use_encoder=False,
                        enc_kernel_sizes=None,
                        enc_strides=None,
                        enc_padding=None,
                        use_autoregressive=True,
                        calc_accuracy=calc_accuracy,
                    )
                )
        else:
            raise ValueError("Invalid option for opt.model_splits")

    def _single_layer_module(
        self,
        layer_idx,
        enc_input,
        kernel_sizes,
        strides,
        padding,
        use_autoregressive,
        calc_accuracy,
    ):
        """Create the module that wraps only encoder layer ``layer_idx``."""
        return independent_module.IndependentModule(
            self.opt,
            enc_input=enc_input,
            enc_kernel_sizes=[kernel_sizes[layer_idx]],
            enc_strides=[strides[layer_idx]],
            enc_padding=[padding[layer_idx]],
            enc_hidden=self.enc_hidden,
            reg_hidden=self.reg_hidden,
            use_autoregressive=use_autoregressive,
            calc_accuracy=calc_accuracy,
        )

    def forward(self, x, filename=None, start_idx=None, n=6):
        """Run the modules in order and collect their losses.

        Args:
            x: Input batch fed to the first module.
            filename: Forwarded unchanged to every module that computes a loss.
            start_idx: Forwarded unchanged to every module that computes a loss.
            n: ``6`` computes the loss of every module; any other value is the
                index of the only module whose loss is computed (splits 5 and
                6 only).

        Returns:
            Tensor of shape ``(1, len(self.fullmodel))`` holding one loss per
            module; entries of modules whose loss was not computed stay zero.
            Accuracies are gathered in the same layout but are not returned.

        Raises:
            AssertionError: If ``n != 6`` and ``opt.model_splits`` is neither
                5 nor 6.
        """
        model_input = x

        cur_device = utils.get_device(self.opt, x)

        # first dimension is used for concatenating results from different GPUs
        loss = torch.zeros(1, len(self.fullmodel), device=cur_device)
        accuracy = torch.zeros(1, len(self.fullmodel), device=cur_device)

        if n == 6:  # train all layers at once
            for idx, layer in enumerate(self.fullmodel):
                loss[:, idx], accuracy[:, idx], _, z = layer(
                    model_input, filename, start_idx
                )
                # detach so no gradient flows back into earlier modules
                model_input = z.permute(0, 2, 1).detach()
        else:
            # Forward to the layer that we want to train and only output that
            # layer's loss (all other values stay at zero initialization).
            # This does not reap the memory benefits that would be possible if
            # we trained layers completely separately (by training a layer and
            # saving its output as the dataset to train the next layer on),
            # but enables us to test the behaviour of the model for greedy
            # iterative training.
            if self.opt.model_splits not in (5, 6):
                # Raised explicitly (rather than via ``assert``) so the check
                # is not stripped under ``python -O``; the exception type is
                # kept for backward compatibility.
                raise AssertionError("Works only for GIM model training")

            for idx, layer in enumerate(self.fullmodel[: n + 1]):
                if idx == n:
                    loss[:, idx], accuracy[:, idx], _, _ = layer(
                        model_input, filename, start_idx
                    )
                else:
                    _, z = layer.get_latents(model_input)
                    model_input = z.permute(0, 2, 1).detach()

        return loss

    def forward_through_n_layers(self, x, n):
        """Return the representation obtained after layer index ``n``.

        For split 1, ``n > 4`` runs every module through ``get_latents`` and
        returns the last ``c``; otherwise the first module's encoder is asked
        for the output of its first ``n + 1`` layers and the result is
        permuted with ``(0, 2, 1)``.

        For splits 5 and 6 the first ``n + 1`` modules are chained through
        ``get_latents``; the last ``z`` is returned when ``n < 5`` and the
        last ``c`` otherwise.

        Args:
            x: Input batch fed to the first module.
            n: Zero-based index of the layer whose output is requested.

        Returns:
            The selected representation tensor.
        """
        if self.opt.model_splits == 1:
            if n > 4:
                model_input = x
                for layer in self.fullmodel:
                    c, z = layer.get_latents(model_input)
                    model_input = z.permute(0, 2, 1).detach()
                x = c
            else:
                x = self.fullmodel[0].encoder.forward_through_n_layers(
                    x, n + 1
                )
                x = x.permute(0, 2, 1)
        elif self.opt.model_splits == 6 or self.opt.model_splits == 5:
            model_input = x
            for layer in self.fullmodel[: n + 1]:
                c, z = layer.get_latents(model_input)
                model_input = z.permute(0, 2, 1).detach()
            if n < 5:
                x = z
            else:
                x = c

        return x

Utilities for the model.

In [None]:
def distribute_over_GPUs(opt, model, num_GPU):
    """Wrap ``model`` in ``nn.DataParallel`` and move it to ``opt.device``.

    Also stores the effective batch size in ``opt.batch_size_multiGPU``:
    ``opt.batch_size`` on CPU, otherwise ``opt.batch_size * num_GPU``.

    Args:
        opt: Options object providing ``device`` and ``batch_size``.
        model: Module to wrap.
        num_GPU: Number of GPUs to use. ``None`` means all visible GPUs when
            ``opt.device`` is not a CPU device.

    Returns:
        Tuple ``(wrapped_model, num_GPU)``. On a CPU device ``num_GPU`` is
        returned exactly as it was passed in.
    """
    running_on_cpu = opt.device.type == "cpu"

    if running_on_cpu:
        model = nn.DataParallel(model)
        opt.batch_size_multiGPU = opt.batch_size
    elif num_GPU is None:
        # No explicit count: use every visible GPU.
        model = nn.DataParallel(model)
        num_GPU = torch.cuda.device_count()
        opt.batch_size_multiGPU = opt.batch_size * num_GPU
    else:
        assert (
            num_GPU <= torch.cuda.device_count()
        ), "You cant use more GPUs than you have."
        gpu_ids = list(range(num_GPU))
        model = nn.DataParallel(model, device_ids=gpu_ids)
        opt.batch_size_multiGPU = opt.batch_size * num_GPU

    model = model.to(opt.device)
    print("Let's use", num_GPU, "GPUs!")

    return model, num_GPU


def genOrthgonal(dim):
    """Return a random ``dim x dim`` orthogonal matrix.

    A matrix with standard-normal entries is QR-decomposed and each column of
    ``Q`` is multiplied by the sign of the matching diagonal entry of ``R``.

    Args:
        dim: Number of rows and columns of the result.

    Returns:
        Tensor of shape ``(dim, dim)``.
    """
    a = torch.zeros((dim, dim)).normal_(0, 1)
    # torch.qr is deprecated; torch.linalg.qr defaults to the same reduced
    # decomposition.
    q, r = torch.linalg.qr(a)
    signs = torch.diag(r, 0).sign()
    # Broadcasting over rows scales column j of q by signs[j].
    q.mul_(signs.unsqueeze(0))
    return q


def makeDeltaOrthogonal(weights, gain):
    """Initialise a 4-D weight tensor in place with a centred orthogonal slice.

    All entries are zeroed, the central position of the last two dimensions
    is filled with the top-left ``rows x cols`` corner of a random orthogonal
    matrix from ``genOrthgonal``, and the whole tensor is scaled by ``gain``.

    Args:
        weights: Tensor indexed as ``weights[:, :, i, j]``; modified in place.
        gain: Factor applied to the tensor after filling.
    """
    n_rows, n_cols = weights.size(0), weights.size(1)
    if n_rows > n_cols:
        # Only a warning: initialisation still proceeds.
        print("In_filters should not be greater than out_filters.")

    weights.data.fill_(0)
    ortho = genOrthgonal(max(n_rows, n_cols))

    centre_h = weights.size(2) // 2
    centre_w = weights.size(3) // 2
    with torch.no_grad():
        weights[:, :, centre_h, centre_w] = ortho[:n_rows, :n_cols]
        weights.mul_(gain)


def reload_weights(opt, model, optimizer, reload_model):
    """Restore model (and optionally optimizer) state from ``opt.model_path``.

    Three cases, checked in this order:

    1. ``opt.model_type == 0`` and ``reload_model``: load the weights saved
       under ``opt.model_num`` (used for training the linear classifier).
    2. ``opt.start_epoch > 0``: load the weights and optimizer states saved
       under ``opt.start_epoch`` to continue training.
    3. Otherwise nothing is loaded.

    For ``opt.experiment == "audio"`` one checkpoint holds the whole model;
    otherwise one checkpoint per entry of ``model.module.encoder`` is read.

    Args:
        opt: Options object (see the attributes used above).
        model: Model whose state is restored in place.
        optimizer: Iterable of optimizers restored in place in case 2.
        reload_model: Whether case 1 may apply.

    Returns:
        Tuple ``(model, optimizer)`` -- the same objects that were passed in.
    """

    def _read_checkpoint(file_name):
        # Deserialize one checkpoint from opt.model_path onto opt.device.
        return torch.load(
            os.path.join(opt.model_path, file_name),
            map_location=opt.device.type,
        )

    ## reload weights for training of the linear classifier
    if (opt.model_type == 0) and reload_model:
        print("Loading weights from ", opt.model_path)

        if opt.experiment == "audio":
            model.load_state_dict(
                _read_checkpoint("model_{}.ckpt".format(opt.model_num))
            )
        else:
            for idx, layer in enumerate(model.module.encoder):
                layer.load_state_dict(
                    _read_checkpoint(
                        "model_{}_{}.ckpt".format(idx, opt.model_num)
                    )
                )
        return model, optimizer

    if not (opt.start_epoch > 0):
        print("Randomly initialized model")
        return model, optimizer

    ## reload weights and optimizers for continuing training
    print("Continuing training from epoch ", opt.start_epoch)

    if opt.experiment == "audio":
        # strict=False: missing/unexpected keys in the checkpoint are ignored.
        model.load_state_dict(
            _read_checkpoint("model_{}.ckpt".format(opt.start_epoch)),
            strict=False,
        )
    else:
        for idx, layer in enumerate(model.module.encoder):
            layer.load_state_dict(
                _read_checkpoint(
                    "model_{}_{}.ckpt".format(idx, opt.start_epoch)
                )
            )

    for i, optim in enumerate(optimizer):
        # With more than 3 splits only the first three optimizers are restored.
        if opt.model_splits > 3 and i > 2:
            break
        optim.load_state_dict(
            _read_checkpoint(
                "optim_{}_{}.ckpt".format(str(i), opt.start_epoch)
            )
        )

    return model, optimizer

This function initialises the model with the dimensions given in the Oord et al. paper, although the paper does not mention padding.

This function also makes sure that only one GPU is used when we're doing supervised loss.

The function also initialises one Adam optimiser for each independently trained part of the model.





In [None]:
def load_model_and_optimizer(
    opt, reload_model=False, calc_accuracy=False, num_GPU=None
):
    """Build the full model, distribute it, and create its optimizers.

    Args:
        opt: Options object; ``loss`` and ``learning_rate`` are read here and
            the object is forwarded to the model and the helper functions.
        reload_model: Forwarded to ``model_utils.reload_weights``.
        calc_accuracy: Forwarded to the model constructor.
        num_GPU: Requested GPU count; forced to 1 when ``opt.loss`` is 1 or 2.

    Returns:
        Tuple ``(model, optimizer)`` where ``optimizer`` is a list with one
        Adam optimizer per entry of ``model.module.fullmodel``. The model is
        returned in training mode.
    """

    def _lr_for(position):
        # A list-valued learning rate supplies one value per module.
        if isinstance(opt.learning_rate, list):
            return opt.learning_rate[position]
        return opt.learning_rate

    ## initialize model
    # original dimensions given in CPC paper (Oord et al.)
    model = full_model.FullModel(
        opt,
        kernel_sizes=[10, 8, 4, 4, 4],
        strides=[5, 4, 2, 2, 2],
        padding=[2, 2, 2, 2, 1],
        enc_hidden=512,
        reg_hidden=256,
        calc_accuracy=calc_accuracy,
    )

    # run on only one GPU for supervised losses
    if opt.loss in (2, 1):
        num_GPU = 1

    model, num_GPU = model_utils.distribute_over_GPUs(opt, model, num_GPU=num_GPU)

    # initialize optimizers
    # We need to have a separate optimizer for every individually trained part
    # of the network as calling optimizer.step() would otherwise cause all
    # parts of the network to be updated even when their respective gradients
    # are zero (due to momentum)
    optimizer = [
        torch.optim.Adam(part.parameters(), lr=_lr_for(position))
        for position, part in enumerate(model.module.fullmodel)
    ]

    model, optimizer = model_utils.reload_weights(opt, model, optimizer, reload_model)

    model.train()
    print(model)

    return model, optimizer