# Define the Model

In [None]:
import types
import torch
import torch.nn as nn
from torch.autograd import Function


def CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add=False):
    """Compute the count sketch of ``x`` along its last dimension.

    Each input channel ``i`` is multiplied by ``s[i]`` and accumulated into the
    output bin ``h[i]``::

        out[..., h[i]] += s[i] * x[..., i]

    Args:
        h: Integer index tensor with ``x.size(-1)`` elements; it is used as the
            ``scatter_add_`` index, so every value must be a valid position in
            an axis of length ``output_size``.
        s: Tensor with ``x.size(-1)`` elements holding the per-channel factors.
        output_size (int): Size of the last dimension of the result.
        x: Input tensor of shape ``(..., input_size)``.
        force_cpu_scatter_add (bool): If True, run the scatter-add on the CPU
            and move the result back to the device ``x`` lives on.

    Returns:
        Tensor of shape ``x.shape[:-1] + (output_size,)``.
    """
    x_size = tuple(x.size())

    # Shape (1, ..., 1, input_size) so that h and s broadcast against x.
    s_view = (1,) * (len(x_size) - 1) + (x_size[-1],)

    out_size = x_size[:-1] + (output_size,)

    # Broadcast s and compute x * s
    xs = x * s.view(s_view)

    # Broadcast h so that every leading position scatters with the same bins:
    # out[h_i] += x_i * s_i
    index = h.view(s_view).expand(x_size)

    if force_cpu_scatter_add:
        out = x.new(*out_size).zero_().cpu()
        # Return to the device of x rather than unconditionally calling
        # .cuda(), which breaks on CPU-only hosts and may pick another GPU.
        return out.scatter_add_(-1, index.cpu(), xs.cpu()).to(x.device)

    out = x.new(*out_size).zero_()
    return out.scatter_add_(-1, index, xs)


def CountSketchFn_backward(h, s, x_size, grad_output):
    """Back-propagate a gradient through the count sketch.

    Args:
        h: Index tensor with ``x_size[-1]`` elements (output bin per channel).
        s: Tensor with ``x_size[-1]`` elements (factor per channel).
        x_size (tuple): Shape of the tensor that was sketched.
        grad_output: Gradient with respect to the sketch output.

    Returns:
        Gradient with respect to the sketched input, of shape ``x_size``.
    """
    rank = len(x_size)
    channels = x_size[-1]
    # (1, ..., 1, channels): lets h and s broadcast over the leading dimensions.
    broadcast_shape = (1,) * (rank - 1) + (channels,)

    bin_index = h.view(broadcast_shape).expand(x_size)
    factors = s.view(broadcast_shape)

    # Each input channel reads back the gradient of the bin it was added to.
    gathered = grad_output.gather(-1, bin_index)
    return gathered * factors


class CountSketchFn(Function):
    """Autograd function wrapping the count sketch forward/backward helpers.

    Only ``x`` receives a gradient; ``h``, ``s``, ``output_size`` and
    ``force_cpu_scatter_add`` are treated as constants.
    """

    @staticmethod
    def forward(ctx, h, s, output_size, x, force_cpu_scatter_add=False):
        """Sketch ``x`` and remember what backward needs (``h``, ``s``, shape of ``x``)."""
        ctx.save_for_backward(h, s)
        ctx.x_size = tuple(x.size())

        return CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient slot per forward input; only ``x`` is differentiable."""
        # saved_tensors is the supported accessor (saved_variables is deprecated).
        h, s = ctx.saved_tensors

        grad_x = CountSketchFn_backward(h, s, ctx.x_size, grad_output)
        # Five slots: h, s, output_size, x, force_cpu_scatter_add.  The trailing
        # None covers calls that pass force_cpu_scatter_add explicitly; autograd
        # drops surplus None entries when it was omitted.
        return None, None, None, grad_x, None


class CountSketch(nn.Module):
    r"""Count sketch projection of the last dimension of a tensor.

    .. math::

        out_j = \sum_{i : j = h_i} s_i x_i

    Args:
        input_size (int): Number of channels in the input array
        output_size (int): Number of channels in the output sketch
        h (array, optional): Array of ``input_size`` bin indices. When omitted it
            is drawn with ``LongTensor.random_(0, output_size)``.
        s (array, optional): Array of ``input_size`` values in {-1, 1}. When
            omitted it is drawn at random.

    Shape:
        - Input: (...,input_size)
        - Output: (...,output_size)

    References:
        Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016).
        Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016).
    """

    def __init__(self, input_size, output_size, h=None, s=None):
        super().__init__()

        self.input_size, self.output_size = input_size, output_size

        # Draw the hash indices first, then the signs (keeps the RNG order).
        if h is None:
            h = torch.LongTensor(input_size).random_(0, output_size)
        if s is None:
            s = 2 * torch.Tensor(input_size).random_(0, 2) - 1

        # h holds indices and must stay a LongTensor even if the module is
        # converted with float() / double(), so both casts are overridden on
        # this tensor object to hand back the tensor unchanged.
        def _return_unchanged(tensor):
            return tensor

        for cast_name in ('float', 'double'):
            setattr(h, cast_name, types.MethodType(_return_unchanged, h))

        self.register_buffer('h', h)
        self.register_buffer('s', s)

    def forward(self, x):
        """Sketch ``x``; its last dimension must equal ``input_size``."""
        channels = x.size(-1)

        assert (channels == self.input_size)

        return CountSketchFn.apply(self.h, self.s, self.output_size, x)


def ComplexMultiply_forward(X_re, X_im, Y_re, Y_im):
    """Element-wise complex product Z = X * Y on separate real/imaginary parts.

    Returns:
        Tuple ``(Z_re, Z_im)`` with ``Z_re = X_re*Y_re - X_im*Y_im`` and
        ``Z_im = X_re*Y_im + X_im*Y_re``.
    """
    real_seed = X_re * Y_re
    imag_seed = X_re * Y_im
    # addcmul(a, b, c, value=v) evaluates a + v * b * c.
    real_part = torch.addcmul(real_seed, X_im, Y_im, value=-1)
    imag_part = torch.addcmul(imag_seed, X_im, Y_re, value=1)
    return real_part, imag_part


def ComplexMultiply_backward(x_re, x_im, y_re, y_im, grad_z_re, grad_z_im):
    """Gradients of the element-wise complex product with respect to both operands.

    Each gradient is the incoming gradient multiplied by the complex conjugate
    of the *other* operand.

    Returns:
        Tuple ``(grad_x_re, grad_x_im, grad_y_re, grad_y_im)``.
    """

    def _grad_times_conjugate(other_re, other_im):
        # (g_re + i g_im) * (o_re - i o_im), split into real and imaginary parts.
        real = torch.addcmul(grad_z_re * other_re, grad_z_im, other_im, value=1)
        imag = torch.addcmul(grad_z_im * other_re, grad_z_re, other_im, value=-1)
        return real, imag

    grad_x_re, grad_x_im = _grad_times_conjugate(y_re, y_im)
    grad_y_re, grad_y_im = _grad_times_conjugate(x_re, x_im)
    return grad_x_re, grad_x_im, grad_y_re, grad_y_im


class ComplexMultiply(Function):
    """Autograd function for an element-wise complex product on split re/im tensors."""

    @staticmethod
    def forward(ctx, x_re, x_im, y_re, y_im):
        """Return ``(z_re, z_im)`` for z = x * y and keep the operands for backward."""
        operands = (x_re, x_im, y_re, y_im)
        ctx.save_for_backward(*operands)
        return ComplexMultiply_forward(*operands)

    @staticmethod
    def backward(ctx, grad_z_re, grad_z_im):
        """Return the gradients for ``x_re, x_im, y_re, y_im`` in that order."""
        saved_operands = ctx.saved_tensors
        return ComplexMultiply_backward(*saved_operands, grad_z_re, grad_z_im)


class CompactBilinearPoolingFn(Function):
    """Autograd function for compact bilinear pooling of two inputs.

    The forward pass count-sketches ``x`` and ``y`` and circularly convolves
    the two sketches along the last dimension via the FFT.  Every transform is
    taken over ``dim=-1`` so that inputs of shape ``(..., channels)`` with any
    number of leading dimensions are handled consistently.
    """

    @staticmethod
    def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Return the pooled tensor of shape ``(..., output_size)``.

        Args:
            h1, s1: Sketch indices and factors applied to ``x``.
            h2, s2: Sketch indices and factors applied to ``y``.
            output_size (int): Length of each sketch and of the output.
            x, y: Input tensors of shape ``(..., channels)``.
            force_cpu_scatter_add (bool): Forwarded to the sketch helper.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        # Compute the count sketch of each input (x ==> px, y ==> py) and take
        # its FFT along the sketch dimension.
        px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        fx = torch.fft.rfft(px, dim=-1)
        fx_re = fx.real
        fx_im = fx.imag
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        fy = torch.fft.rfft(py, dim=-1)
        fy_re = fy.real
        fy_im = fy.imag
        del py

        # Convolution of the two sketches: element-wise complex product in the
        # Fourier domain.
        prod_re, prod_im = ComplexMultiply_forward(fx_re, fx_im, fy_re, fy_im)
        complex_prod = torch.complex(prod_re, prod_im)

        # Back to the real domain; n restores the original sketch length.
        re = torch.fft.irfft(complex_prod, n=output_size, dim=-1)

        return re

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for the eight forward inputs; only ``x`` and ``y`` get one."""
        h1, s1, h2, s2, x, y = ctx.saved_tensors

        # Recompute part of the forward pass to get the input to the complex product
        # Compute the count sketch of each input
        px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add)
        py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add)

        # Then convert the output gradient to the Fourier domain
        grad_output = grad_output.contiguous()
        grad_prod = torch.fft.rfft(grad_output, dim=-1)
        grad_re_prod = grad_prod.real
        grad_im_prod = grad_prod.imag

        # Compute the gradient of x first then y

        # Gradient of x
        # Recompute fy
        fy = torch.fft.rfft(py, dim=-1)
        re_fy = fy.real
        im_fy = fy.imag
        del py
        # Gradient of fx is grad_prod * conj(fy); then back to temporal space
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, grad_im_prod, im_fy, value=1)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, grad_re_prod, im_fy, value=-1)
        complex_fx = torch.complex(grad_re_fx, grad_im_fx)
        grad_fx = torch.fft.irfft(complex_fx, n=ctx.output_size, dim=-1)
        # Finally compute the gradient of x
        grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
        del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx

        # Gradient of y
        # Recompute fx
        fx = torch.fft.rfft(px, dim=-1)
        re_fx = fx.real
        im_fx = fx.imag
        del px
        # Gradient of fy is grad_prod * conj(fx); then back to temporal space
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, grad_im_prod, im_fx, value=1)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, grad_re_prod, im_fx, value=-1)
        complex_fy = torch.complex(grad_re_fy, grad_im_fy)
        grad_fy = torch.fft.irfft(complex_fy, n=ctx.output_size, dim=-1)
        # Finally compute the gradient of y
        grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
        del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy

        # One slot per forward input: h1, s1, h2, s2, output_size, x, y, flag.
        return None, None, None, None, None, grad_x, grad_y, None


class CompactBilinearPooling(nn.Module):
    r"""Compact bilinear pooling of two input arrays x and y.

    .. math::

        out = \Psi (x,h_1,s_1) \ast \Psi (y,h_2,s_2)

    Args:
        input1_size (int): Number of channels in the first input array
        input2_size (int): Number of channels in the second input array
        output_size (int): Number of channels in the output array
        h1 (array, optional): Sketch indices for the first input (``input1_size`` entries)
        s1 (array, optional): Sketch signs (-1 / 1) for the first input
        h2 (array, optional): Sketch indices for the second input (``input2_size`` entries)
        s2 (array, optional): Sketch signs (-1 / 1) for the second input
        force_cpu_scatter_add (boolean, optional): Force the scatter_add operation to run on CPU for testing purposes

    .. note::

        Any of h1, s1, h2, s2 left as None is generated at random by ``CountSketch``.

    Shape:
        - Input 1: (...,input1_size)
        - Input 2: (...,input2_size)
        - Output: (...,output_size)

    References:
        Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016).
        Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016).
    """

    def __init__(self, input1_size, input2_size, output_size, h1=None, s1=None, h2=None, s2=None,
                 force_cpu_scatter_add=False):
        super().__init__()
        # The sketch sub-modules only serve as holders of the h / s buffers.
        self.sketch1 = CountSketch(input1_size, output_size, h1, s1)
        self.sketch2 = CountSketch(input2_size, output_size, h2, s2)
        self.output_size = output_size
        self.force_cpu_scatter_add = force_cpu_scatter_add

    def forward(self, x, y=None):
        """Pool ``x`` with ``y``; with ``y`` omitted, ``x`` is pooled with itself."""
        second_input = x if y is None else y
        first_sketch, second_sketch = self.sketch1, self.sketch2
        return CompactBilinearPoolingFn.apply(
            first_sketch.h, first_sketch.s,
            second_sketch.h, second_sketch.s,
            self.output_size, x, second_input, self.force_cpu_scatter_add,
        )

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import numpy as np
from functools import partial
from torch.autograd import Function
import math
# -*- coding: utf-8 -*-

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import scipy
from scipy import stats
from scipy.optimize import curve_fit
import random
import time


def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Four-parameter logistic curve.

    Evaluates ``bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))``.

    Args:
        X: Input value(s).
        bayta1, bayta2: The two asymptotes of the curve.
        bayta3: Midpoint of the curve.
        bayta4: Scale; only its magnitude is used.

    Returns:
        The curve evaluated at ``X``.
    """
    # The small offset keeps the scale strictly positive (no division by zero).
    scale = np.abs(bayta4) + 1e-5
    # Clipping the exponent keeps np.exp from overflowing or underflowing.
    exponent = np.clip(-(X - bayta3) / scale, -500, 500)
    spread = bayta1 - bayta2
    return bayta2 + spread / (1 + np.exp(exponent))


def fit_function(y_label, y_output):
    """Fit ``logistic_func`` from ``y_output`` to ``y_label`` and return the mapped outputs.

    The initial parameter guess is
    ``[max(y_label), min(y_label), mean(y_output), range(y_output)]``.

    Args:
        y_label: Target values.
        y_output: Values to be mapped onto the targets.

    Returns:
        ``y_output`` passed through the fitted logistic curve.
    """
    initial_guess = [
        np.max(y_label),
        np.min(y_label),
        np.mean(y_output),
        np.max(y_output) - np.min(y_output),
    ]
    fitted_params = curve_fit(
        logistic_func, y_output, y_label, p0=initial_guess, maxfev=1000000
    )[0]
    return logistic_func(y_output, *fitted_params)


class Conv1x1WeightedSum(nn.Module):
    """Blend two feature tensors with weights predicted by a 1x1 convolution.

    The inputs are concatenated along the channel axis, a kernel-size-1
    ``Conv1d`` followed by a sigmoid produces a gate ``w``, and the result
    ``w * x1 + (1 - w) * x2`` is average-pooled and batch-normalised.

    Args:
        input_dim_1 (int): Channels of the first input.
        input_dim_2 (int): Channels of the second input.
        output_dim (int): Channels of the gate and of the output.
    """

    def __init__(self, input_dim_1, input_dim_2, output_dim):
        super().__init__()
        stacked_channels = input_dim_1 + input_dim_2
        self.conv = nn.Conv1d(stacked_channels, output_dim, 1)
        self.global_pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.batch_norm = nn.BatchNorm1d(output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """Return the fused features of shape ``(batch, output_dim)``."""
        # Bring both inputs to (batch_size, channels, sequence_length).
        first = x1.view(x1.size(0), x1.size(1), -1)
        second = x2.view(x2.size(0), x2.size(1), -1)

        # The gate is predicted from both inputs stacked on the channel axis.
        stacked = torch.cat((first, second), dim=1)
        gate = self.sigmoid(self.conv(stacked))

        # Convex combination of the two inputs.
        blended = gate * first + (1 - gate) * second

        # Average over the sequence axis, drop it, then normalise.
        pooled = self.global_pool(blended).squeeze(-1)
        return self.batch_norm(pooled)


class MultiplyFeatureVectors(nn.Module):
    """Fuse two feature vectors with a 1x1 convolution and average pooling.

    The two inputs are placed side by side on a new trailing axis of length 2,
    a kernel-size-1 ``Conv1d`` mixes the channels at each of the two positions,
    and the positions are averaged and batch-normalised.

    Args:
        in_features (int): Number of channels of each input and of the output.
    """

    def __init__(self, in_features):
        super(MultiplyFeatureVectors, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=1)
        self.global_pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.batch_norm = nn.BatchNorm1d(in_features)

    def forward(self, x1, x2):
        """Return the fused features.

        Args:
            x1, x2: Tensors whose dimension 1 has ``in_features`` channels
                (presumably ``(batch, in_features)`` -- confirm with callers).

        Returns:
            Tensor with the trailing pooled axis removed, i.e.
            ``(batch, in_features)`` for 2-D inputs.
        """
        # Stack the two vectors on a new trailing (length) axis.
        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2)
        # Apply the 1x1 convolution at each of the two positions.
        output = self.conv1(x)
        # Average over the two positions.
        output = self.global_pool(output)
        # Normalise per channel.
        output = self.batch_norm(output)
        # Drop only the pooled axis; a bare squeeze() would also remove the
        # batch axis whenever the batch size is 1.
        return output.squeeze(-1)

def global_std_pool2d(x):
    """2D global standard deviation pooling.

    Collapses every dimension after the first two and returns their standard
    deviation as a tensor of shape ``(x.size(0), x.size(1), 1, 1)``.
    """
    batch, channels = x.size()[0], x.size()[1]
    flattened = x.view(batch, channels, -1, 1)
    return torch.std(flattened, dim=2, keepdim=True)


# Public names of this module: the backbone class and its factory functions.
__all__ = ['myResNet', 'myresnet18', 'myresnet34', 'myresnet50', 'myresnet101',
           'myresnet152', 'myresnext50_32x4d', 'myresnext101_32x8d',
           'mywide_resnet50_2', 'mywide_resnet101_2']

# Checkpoint download URLs keyed by architecture name; the factory functions
# below look them up when called with pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Two-layer residual block (3x3 conv -> 3x3 conv) with an identity shortcut.

    Args:
        inplanes (int): Input channels.
        planes (int): Output channels.
        stride (int): Stride of the first convolution.
        downsample (nn.Module, optional): Applied to the input on the shortcut path.
        groups (int): Must be 1.
        base_width (int): Must be 64.
        dilation (int): Must not exceed 1.
        norm_layer (callable, optional): Normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, conv1 and the downsample layer both reduce the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class Bottleneck(nn.Module):
    """Three-layer residual block (1x1 -> 3x3 -> 1x1) with an identity shortcut.

    The output has ``planes * expansion`` channels.

    Args:
        inplanes (int): Input channels.
        planes (int): Base channel count of the block.
        stride (int): Stride of the 3x3 convolution.
        downsample (nn.Module, optional): Applied to the input on the shortcut path.
        groups (int): Groups of the 3x3 convolution.
        base_width (int): Width per group; scales the inner channel count.
        dilation (int): Dilation of the 3x3 convolution.
        norm_layer (callable, optional): Normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner_width = int(planes * (base_width / 64.)) * groups
        outplanes = planes * self.expansion

        # When stride != 1, conv2 and the downsample layer both reduce the resolution.
        self.conv1 = conv1x1(inplanes, inner_width)
        self.bn1 = norm_layer(inner_width)
        self.conv2 = conv3x3(inner_width, inner_width, stride, groups, dilation)
        self.bn2 = norm_layer(inner_width)
        self.conv3 = conv1x1(inner_width, outplanes)
        self.bn3 = norm_layer(outplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class myResNet(nn.Module):
    """ResNet backbone whose pooled features are fused with a second feature tensor.

    The trunk is the usual ResNet layout (stem, four residual stages, global
    average pooling).  Instead of a classification head, the pooled features
    are combined with externally supplied features by one of four fusion
    strategies and mapped to a single value by ``self.quality``.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        feature_fusion_method (int): Fusion strategy:
            0 - concatenation, 1 - ``MultiplyFeatureVectors``,
            2 - ``Conv1x1WeightedSum``, 3 - ``CompactBilinearPooling``.
        num_classes (int): Accepted but not used by this class.
        zero_init_residual (bool): Zero-initialise the last norm layer of each
            residual block.
        groups (int): Groups passed to every block.
        width_per_group (int): Base width passed to every block.
        replace_stride_with_dilation: ``None`` or three booleans telling whether
            stages 2-4 use dilation instead of stride.
        norm_layer (callable, optional): Normalisation layer factory; defaults
            to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have three
            elements or ``feature_fusion_method`` is not one of 0, 1, 2, 3.
    """

    def __init__(self, block, layers, feature_fusion_method, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(myResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # Width of the pooled backbone feature, derived from the block type:
        # 512 for BasicBlock, 2048 for Bottleneck.
        backbone_dim = 512 * block.expansion

        self.bn_img = nn.BatchNorm1d(128)
        self.bn_video = nn.BatchNorm1d(128)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if feature_fusion_method == 0:
            print("[WARN]: Feature Fusion By Concat")
            self.adjust1 = nn.Linear(backbone_dim, 128)
            self.adjust2 = nn.Linear(256, 128)
            self.quality = nn.Linear(128 + 128, 1)
            self.feature_fusion = self.concat_
        elif feature_fusion_method == 1:
            print("[WARN]: Feature Fusion By Multiply")
            self.adjust1 = nn.Linear(backbone_dim, 128)
            self.adjust2 = nn.Linear(256, 128)
            self.feature_fusion = self.multiply
            self.fconv = MultiplyFeatureVectors(128)
            self.quality = nn.Linear(128, 1)
        elif feature_fusion_method == 2:
            print("[WARN]: Feature Fusion By 1x1Conv")
            self.adjust1 = nn.Linear(backbone_dim, 128)
            self.adjust2 = nn.Linear(256, 128)
            self.fconv = Conv1x1WeightedSum(128, 128, 128)
            self.feature_fusion = self.convolve1x1
            self.quality = nn.Linear(128, 1)
        elif feature_fusion_method == 3:
            print("[WARN]: Feature Fusion By CompactMultiLinearPooling")
            self.fconv = CompactBilinearPooling(256, 256, 512)
            self.adjust1 = nn.Linear(backbone_dim, 256)
            # The image branch is projected to 256 channels here, so its
            # normalisation layer is replaced accordingly.
            self.bn_img = nn.BatchNorm1d(256)
            self.feature_fusion = self.cmap
            self.quality = nn.Linear(512, 1)
        else:
            # Without a fusion branch the model has no feature_fusion / quality
            # attributes and would only fail later, inside forward().
            raise ValueError(
                "feature_fusion_method should be one of 0, 1, 2, 3, got {}".format(feature_fusion_method))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def flatten(self, itensors):
        """Flatten both tensors of the pair from dim 1 onwards (the list is modified in place)."""
        itensors[0] = torch.flatten(itensors[0], 1)
        itensors[1] = torch.flatten(itensors[1], 1)
        return itensors

    def normalize(self, itensors):
        """Project each tensor with its ``adjust`` layer, then batch-normalise it (in place)."""
        itensors[0] = self.bn_img((self.adjust1(itensors[0])))
        itensors[1] = self.bn_video(self.adjust2(itensors[1]))
        return itensors

    def concat_(self, itensors, **kwargs):
        """Fusion 0: concatenate the two normalised feature vectors along dim 1."""
        itensors = self.flatten(itensors)
        itensors = self.normalize(itensors)
        return torch.cat(tuple(itensors), dim=1)

    def multiply(self, itensors, **kwargs):
        """Fusion 1: pass the two normalised feature vectors through ``self.fconv``."""
        itensors = self.flatten(itensors)
        itensors = self.normalize(itensors)
        return self.fconv(itensors[0], itensors[1])

    def convolve1x1(self, itensors, **kwargs):
        """Fusion 2: pass the two normalised feature vectors through ``self.fconv``."""
        itensors = self.flatten(itensors)
        itensors = self.normalize(itensors)
        return self.fconv(itensors[0], itensors[1])

    def cmap(self, itensors, **kwargs):
        """Fusion 3: only the first tensor is projected and normalised; the second is used as is."""
        itensors = self.flatten(itensors)
        itensors[0] = self.bn_img((self.adjust1(itensors[0])))
        return self.fconv(itensors[0], itensors[1])

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # The shortcut must match the main path in resolution and channels.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def quality_pred(self, in_channels, middle_channels, out_channels):
        """Return a two-layer linear regression head (not used in ``__init__`` or ``forward``)."""
        regression_block = nn.Sequential(
            nn.Linear(in_channels, middle_channels),
            nn.Linear(middle_channels, out_channels),
        )

        return regression_block

    def hyper_structure1(self, in_channels, out_channels):
        """Return a 1x1 -> 3x3 -> 1x1 convolution stack with stride 1 throughout."""
        hyper_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        )

        return hyper_block

    def hyper_structure2(self, in_channels, out_channels):
        """Return a 1x1 -> 3x3 -> 1x1 convolution stack whose 3x3 layer has stride 2."""
        hyper_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
        )

        return hyper_block

    def forward(self, x, x_fast_features):
        """Predict one value per leading index of ``x``.

        Args:
            x: 5-D tensor whose first two dimensions are merged before the
                backbone (presumably ``(batch, frames, 3, H, W)`` -- confirm
                with the data loader).
            x_fast_features: 3-D tensor whose first two dimensions are merged
                the same way; its last dimension feeds the fusion layers.

        Returns:
            Tensor of shape ``(x.size(0),)``: the per-slice predictions
            averaged over dimension 1.
        """
        x_size = x.shape
        x_fast_features_size = x_fast_features.shape
        # Fold dimensions 0 and 1 together so the backbone sees a plain image batch.
        x = x.view(-1, x_size[2], x_size[3], x_size[4])
        x_fast_features = x_fast_features.view(-1, x_fast_features_size[2])
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.layer4(x)
        x = self.avgpool(x)
        # Fuse the pooled backbone features with the external features.
        x = self.feature_fusion([x, x_fast_features], dim=1, x_size=x_size)
        output = self.quality(x)
        # Restore the two leading dimensions and average over the second one.
        output = output.view(x_size[0], x_size[1])
        output = torch.mean(output, dim=1)
        return output

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``myResNet`` and optionally load the matching pretrained backbone weights.

    Args:
        arch (str): Key into ``model_urls``.
        block: Residual block class.
        layers: Number of blocks per stage.
        pretrained (bool): If True, download the checkpoint for ``arch`` and
            copy every entry whose key also exists in the model.
        progress (bool): If True, display a download progress bar.
        **kwargs: Forwarded to ``myResNet``.

    Returns:
        The constructed model.
    """
    model = myResNet(block, layers, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pre_train_model = model_zoo.load_url(model_urls[arch], progress=progress)
        # The checkpoint carries classifier weights (fc.*) that this model does
        # not have, and the model has fusion layers the checkpoint lacks, so
        # only the shared keys are copied.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def myresnet18(pretrained=False, progress=True, **kwargs):
    r"""Build the ResNet-18 variant of ``myResNet``, after
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs)


def myresnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet`` (``feature_fusion_method`` is required there)
    """
    model = myResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        # Pass `progress` through so the documented flag takes effect.
        pre_train_model = model_zoo.load_url(model_urls['resnet34'], progress=progress)
        # Copy only the entries that also exist in this model (the checkpoint's
        # classifier weights have no counterpart here).
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def myresnet50(pretrained=False, progress=True, **kwargs):
    r"""myResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet``; ``feature_fusion_method`` defaults to 0
    """
    print(kwargs)
    feature_fusion_method = kwargs.pop('feature_fusion_method', 0)
    # Forward the remaining options (e.g. zero_init_residual) instead of
    # silently discarding them.
    model = myResNet(Bottleneck, [3, 4, 6, 3], feature_fusion_method=feature_fusion_method, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        # Pass `progress` through so the documented flag takes effect.
        pre_train_model = model_zoo.load_url(model_urls['resnet50'], progress=progress)
        # Copy only the entries that also exist in this model, skipping any
        # key that contains 'branch_'.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict and not ('branch_' in k)}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
        print ('load the pretrained model, done！')
    return model


def myresnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet`` (``feature_fusion_method`` is required there)
    """
    model = myResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        # Pass `progress` through so the documented flag takes effect.
        pre_train_model = model_zoo.load_url(model_urls['resnet101'], progress=progress)
        # Copy only the entries that also exist in this model.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def myresnet152(pretrained=False, progress=True, **kwargs):
    r"""myResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet`` (``feature_fusion_method`` is required there)
    """
    model = myResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        # Pass `progress` through so the documented flag takes effect.
        pre_train_model = model_zoo.load_url(model_urls['resnet152'], progress=progress)
        # Copy only the entries that also exist in this model.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def myresnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet``; ``feature_fusion_method`` defaults to 0
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    # myResNet takes (block, layers, feature_fusion_method, ...).  Passing
    # `pretrained` and `progress` positionally would land in
    # feature_fusion_method and num_classes, so only keyword options are sent.
    kwargs.setdefault('feature_fusion_method', 0)
    model = myResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pre_train_model = model_zoo.load_url(model_urls['resnext50_32x4d'], progress=progress)
        # Copy only the entries that also exist in this model.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def myresnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
    Args:
        pretrained (bool): If True, the backbone weights pre-trained on ImageNet are loaded
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded to ``myResNet``; ``feature_fusion_method`` defaults to 0
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    # myResNet takes (block, layers, feature_fusion_method, ...).  Passing
    # `pretrained` and `progress` positionally would land in
    # feature_fusion_method and num_classes, so only keyword options are sent.
    kwargs.setdefault('feature_fusion_method', 0)
    model = myResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pre_train_model = model_zoo.load_url(model_urls['resnext101_32x8d'], progress=progress)
        # Copy only the entries that also exist in this model.
        pre_train_model = {k: v for k, v in pre_train_model.items() if k in model_dict}
        model_dict.update(pre_train_model)
        model.load_state_dict(model_dict)
    return model


def mywide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same layout as ResNet-50, but the bottleneck (inner) channel count of every
    block is doubled while the outer 1x1 convolutions keep their width, e.g.
    the last block is 2048-1024-2048 instead of 2048-512-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # NOTE(review): unlike the myresnext* builders this still delegates to
    # `_resnet` rather than `myResNet` -- confirm `_resnet` is defined in this
    # file and returns the intended model class.
    kwargs['width_per_group'] = 128  # 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def mywide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same layout as ResNet-101, but the bottleneck (inner) channel count of
    every block is doubled while the outer 1x1 convolutions keep their width,
    e.g. the last block is 2048-1024-2048 instead of 2048-512-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # NOTE(review): unlike the myresnext* builders this still delegates to
    # `_resnet` rather than `myResNet` -- confirm `_resnet` is defined in this
    # file and returns the intended model class.
    kwargs['width_per_group'] = 128  # 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


# Define the DataLoader


In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms

import torch
from torch.utils import data


class VideoDataset_NR_image_with_fast_features(data.Dataset):
    """Dataset of (frames, fast features, MOS) triples, fully loaded at construction.

    Every sample listed in the CSV is read and transformed once inside
    ``__init__`` and kept in ``self.items``; ``__getitem__`` only looks the
    cached tensors up. As a consequence, ``transform`` is applied exactly once
    per frame (a random transform is not re-drawn on later accesses).

    Args:
        data_dir: Directory holding one sub-folder of ``NNN.png`` frames per sample name.
        data_dir_3D: Directory holding one sub-folder of
            ``feature_<i>_fast_feature.npy`` files per sample (the folder name
            is the sample name up to its first ``.``).
        datainfo_path: CSV file with at least the columns ``name`` and ``mos``.
        transform: Callable applied to each RGB PIL image; its result is stored
            in a ``(3, crop_size, crop_size)`` slot.
        crop_size: Height and width of the stored frames.
        frame_index: Number of the first frame to read.
        video_length_read: Number of frames (and feature vectors) per sample.
    """
    def __init__(self, data_dir, data_dir_3D , datainfo_path, transform, crop_size, frame_index=1, video_length_read = 4):
        super().__init__()

        # column_names = ['vid_name', 'scene', 'dis_type_level']
        dataInfo = pd.read_csv(datainfo_path, header = 0, sep=',', index_col=False, encoding="utf-8-sig")

        self.video_names = dataInfo['name']
        self.moss = dataInfo['mos']

        self.crop_size = crop_size
        self.data_dir = data_dir
        self.data_dir_3D = data_dir_3D
        self.transform = transform
        self.length = len(self.video_names)
        self.frame_index = frame_index
        self.video_length_read = video_length_read
        self.items = [
            [self._load_frames(video_name), self._load_features(video_name)]
            for video_name in self.video_names
        ]

    def _load_frames(self, video_name):
        """Return a (video_length_read, 3, crop, crop) tensor of transformed frames."""
        frames_dir = os.path.join(self.data_dir, video_name)
        n_frames = self.video_length_read
        video = torch.zeros([n_frames, 3, self.crop_size, self.crop_size])

        n_read = 0
        for i in range(n_frames):
            # select one frame every 30 frames, starting at frame_index
            image_path = os.path.join(frames_dir, str(self.frame_index + i * 30).zfill(3) + '.png')
            if os.path.exists(image_path):
                # Close the file handle as soon as the pixels are converted.
                with Image.open(image_path) as frame:
                    # Pack frames contiguously (slot n_read, not i) so a missing
                    # frame never leaves a zero hole before later valid frames.
                    video[n_read] = self.transform(frame.convert('RGB'))
                n_read += 1
            else:
                print(image_path)
                print('Image do not exist!')

        # Pad the tail by repeating the last frame that was actually read.
        # With no frame at all the tensor simply stays zero.
        if 0 < n_read < n_frames:
            video[n_read:] = video[n_read - 1]
        return video

    def _load_features(self, video_name):
        """Return a (video_length_read, 256) tensor of pre-extracted fast features."""
        feature_dir = os.path.join(self.data_dir_3D, video_name.split('.')[0])
        features = torch.zeros([self.video_length_read, 256])
        for i in range(self.video_length_read):
            feature_path = os.path.join(feature_dir, 'feature_' + str(i) + '_fast_feature.npy')
            features[i] = torch.from_numpy(np.load(feature_path)).squeeze()
        return features

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(frames, features, mos)`` for sample ``idx``."""
        return *self.items[idx], self.moss.iloc[idx]


# FastAI

In [None]:
%%capture
import sys
import subprocess
import pkg_resources

# Install nbdev / fastbook only when they are not already present in the
# current environment, then list the installed version of each.
required = {'nbdev', 'fastbook'}
# Keys of every distribution currently visible to pkg_resources.
installed = {pkg.key for pkg in pkg_resources.working_set}
missing = required - installed
print(missing)

if 'nbdev' in missing:
    print("INSTALLING nbdev")
    !pip install nbdev
# Runs unconditionally: shows the nbdev entry of `pip list`.
!pip list |grep nbdev
if 'fastbook' in missing:
    print("INSTALLING fastbook")
    !pip install -Uqq fastbook
# Runs unconditionally: shows the fastbook entry of `pip list`.
!pip list |grep fastbook

from fastai import *
from fastai.vision import *

In [None]:
# Now we SET UP Fastai
import fastbook
fastbook.setup_book()
import fastai
import fastcore
print(f'fastcore version {fastcore.__version__} installed')
print(f'fastai version {fastai.__version__} installed')
from nbdev.showdoc import *
from fastai.vision.all import *

In [None]:
@patch
@delegates(subplots)
def plot_metrics(self: Recorder, nrows=None, ncols=None, figsize=None, fname='resultado.png',**kwargs):
    """Plot the recorded losses and metrics and save the figure to ``fname``.

    Train and validation loss share the first axis; every further metric gets
    its own axis.

    Args:
        nrows, ncols: Grid shape; whichever is ``None`` is derived from the
            number of plots.
        figsize: Figure size; defaults to ``(ncols * 6, nrows * 4)``.
        fname: Path of the image file the figure is written to.
        **kwargs: Forwarded to ``subplots``.
    """
    metrics = np.stack(self.values)
    names = self.metric_names[1:-1]
    # Two loss columns share one axis, hence one plot fewer than names.
    n = len(names) - 1
    if nrows is None and ncols is None:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n / nrows))
    elif nrows is None: nrows = int(np.ceil(n / ncols))
    elif ncols is None: ncols = int(np.ceil(n / nrows))
    figsize = figsize or (ncols * 6, nrows * 4)
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    # Hide the unused cells of the grid and keep only the n axes that are drawn on.
    axs = [ax if i < n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    # The first axis is repeated so that both loss curves land on it.
    for i, (name, ax) in enumerate(zip(names, [axs[0]] + axs)):
        ax.plot(metrics[:, i], color='#1f77b4' if i == 0 else '#ff7f0e', label='valid' if i > 0 else 'train')
        ax.set_title(name if i > 1 else 'losses')
        ax.legend(loc='best')
    # Save BEFORE showing: plt.show() may close/clear the figure (notably with
    # the notebook inline backend), after which a save would write a blank image.
    fig.savefig(fname)
    plt.show()

In [None]:
%%capture
!pip install wandb

In [None]:
import wandb
from fastai.callback.wandb import *
wandb.login(relogin=True, key='...')

In [None]:
!cp /content/gdrive/MyDrive/VQA_PC.zip /content
!unzip -q -x VQA_PC.zip

In [None]:
%cd /content/VQA_PC/extraction
!unzip -q -x our_data_features.zip
%cd /content/VQA_PC/rotation
!unzip -q -x ourdata.zip
!unzip -q -x ourdata_resized.zip

In [None]:
%cd /content/VQA_PC/main

# Define the Dataloader function and the base config

In [None]:
from fastai import *
from fastai.vision import *
class dotdict(dict):
    """dict whose keys can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Unknown keys resolve to None instead of raising.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]


# Base experiment configuration (SJTU database).
config = dotdict(
    database='SJTU',
    conv_base_lr=0.00004,
    decay_ratio=0.9,
    decay_interval=10,
    train_batch_size=32,
    num_workers=2,
    epochs=30,
    split_num=9,
    crop_size=224,
    frame_index=5,
    video_length_read=4,
    pretrained_model_path='ckpts/ResNet_mean_with_fast_LSPCQA_1_best.pth',
    feature_fusion_method=0,
)


def get_dataloader(config: dict, split: int = 0):
  """Build the fastai ``DataLoaders`` for one train/test split of a database.

  Args:
    config: Attribute-style configuration (e.g. a ``dotdict``). Reads
      ``database``, ``crop_size``, ``frame_index``, ``video_length_read``,
      ``train_batch_size`` and ``num_workers``.
    split: Zero-based split index; the CSV files used are
      ``train_<split+1>.csv`` and ``test_<split+1>.csv``.

  Supported ``config.database`` values: ``'SJTU'``, ``'SJTU_resized'``,
  ``'WPC'``, any name containing ``'LS_SJTU'`` and any name containing
  ``'OURDATA'``. For the last two, a ``'SCALED'`` substring selects the
  ``*_scaled`` label folder; for ``OURDATA`` a ``'RESIZED'`` substring selects
  the resized images.

  Returns:
    ``DataLoaders`` holding the train and validation loaders.

  Raises:
    ValueError: If ``config.database`` is not one of the supported names.
  """
  # Crop to the same size the dataset allocates its frame buffers with.
  crop = config.crop_size
  normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
  transformations_train = transforms.Compose([transforms.RandomCrop(crop), transforms.ToTensor(), normalize])
  transformations_test = transforms.Compose([transforms.CenterCrop(crop), transforms.ToTensor(), normalize])

  database = config.database
  if not isinstance(database, str):
    raise ValueError(f'config.database must be a string, got {database!r}')

  if database == 'SJTU':
    images_dir = 'database/sjtu_2d/'
    info_dir = 'database/sjtu_data_info'
    data_3d_dir = 'database/sjtu_slowfast/'
  elif database == 'SJTU_resized':
    images_dir = 'database/sjtu_2d_resized/'
    info_dir = 'database/sjtu_data_info'
    data_3d_dir = 'database/sjtu_slowfast/'
  elif database == 'WPC':
    images_dir = 'database/wpc_2d/'
    info_dir = 'database/wpc_data_info'
    data_3d_dir = 'database/wpc_slowfast/'
  elif 'LS_SJTU' in database:
    images_dir = '../rotation/imgs/'
    info_dir = 'database/ls_sjtu_data_info_scaled' if 'SCALED' in database else 'database/ls_sjtu_data_info'
    data_3d_dir = '../extraction/ls_sjtu_features/'
  elif 'OURDATA' in database:
    images_dir = '../rotation/ourimgs_resized/' if 'RESIZED' in database else '../rotation/ourimgs/'
    info_dir = 'database/our_data_info_scaled' if 'SCALED' in database else 'database/our_data_info'
    data_3d_dir = '../extraction/our_data_features/'
  else:
    # Previously this fell through and failed later with UnboundLocalError.
    raise ValueError(f'Unknown database: {database!r}')

  datainfo_train = info_dir + '/train_' + str(split+1) + '.csv'
  datainfo_test = info_dir + '/test_' + str(split+1) + '.csv'

  dataset_kwargs = dict(crop_size=config.crop_size, frame_index=config.frame_index,
                        video_length_read=config.video_length_read)
  trainset = VideoDataset_NR_image_with_fast_features(images_dir, data_3d_dir, datainfo_train, transformations_train, **dataset_kwargs)
  testset = VideoDataset_NR_image_with_fast_features(images_dir, data_3d_dir, datainfo_test, transformations_test, **dataset_kwargs)

  # NOTE(review): the training loader does not shuffle -- confirm this is intended.
  train_loader = DataLoader(trainset, batch_size=config.train_batch_size, n_inp=2,
                            shuffle=False, num_workers=config.num_workers, pin_memory=True)

  val_loader = TfmdDL(testset, batch_size=1, n_inp=2,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True)
  dls =  fastai.data.core.DataLoaders(train_loader, val_loader, device="cuda")
  # Each item is (frames, features, mos): the first two elements are model inputs.
  dls.train.n_inp=2
  # dls.valid.n_inp=2

  return dls

# K-fold + logging

In [None]:
import functools
import gc
def clean_gpu_memory():
  """Drop unreachable Python objects, then release cached CUDA memory."""
  # Collect first: tensors kept alive only by reference cycles must be freed
  # before empty_cache() can hand their blocks back to the driver.
  gc.collect()
  torch.cuda.empty_cache()

clean_gpu_memory()

def kfold_model(dls_config: dict, model_func, optimizer, kfold: int = 9, n_epoch: int = 15,
                model_name: str = 'MVQAPC_model', project_name: str = 'MVQAPC',
                group: str = 'quality',
                dataset: str = 'SJTU-PCQA', architecture: str = 'PLAIN',
                leslie: bool = False):
  """Train and log one fold inside a wandb run (intended as a sweep agent target).

  The fold to train is read from ``wandb.config.kfold`` (supplied by the
  sweep), not from the ``kfold`` argument.

  Args:
    dls_config: Configuration passed to ``get_dataloader``; its
      ``conv_base_lr`` is the learning rate when ``leslie`` is False.
    model_func: Zero-argument callable returning a fresh model.
    optimizer: Optimizer factory used as the Learner's ``opt_func`` and
      recorded in the run config.
    kfold: Unused; kept for interface compatibility.
    n_epoch: Maximum number of epochs (early stopping may end sooner).
    model_name: Prefix of the model name logged by ``WandbCallback``; the
      fold index is appended.
    project_name: wandb project.
    group: wandb run group.
    dataset: Dataset name recorded in the run config and passed to ``WandbCallback``.
    architecture: Architecture label recorded in the run config.
    leslie: If True, run ``lr_find`` and train with ``fit_one_cycle``;
      otherwise train with plain ``fit``.
  """
  with wandb.init(
      project=project_name,
      group=group,
      config=
      {
      "architecture": architecture,
      "dataset": dataset,
      "epochs": n_epoch,
      "optimizer": optimizer,
      "leslie": leslie
      },
      dir='/content'
                  ):
    config = wandb.config
    srocc = SpearmanCorrCoef()
    cbs = [EarlyStoppingCallback(monitor='valid_loss', patience=6)]
    metrics = [rmse, srocc]
    dls = learn = None
    try:
      dls = get_dataloader(dls_config, config.kfold)
      # Actually use the requested optimizer instead of only logging it.
      learn = Learner(dls, model_func(), opt_func=optimizer, metrics=metrics,
                      loss_func=MSELossFlat())
      if leslie:
        lr = learn.lr_find(show_plot=False)
        train_func = functools.partial(learn.fit_one_cycle, lr_max=lr[0])
      else:
        train_func = functools.partial(learn.fit, lr=dls_config.conv_base_lr)

      train_func(n_epoch=n_epoch,
                 cbs=cbs+[WandbCallback(log_preds_every_epoch=True,
                                        model_name=f'{model_name}_{config.kfold}',
                                        dataset_name=dataset
                                        )
                          ],
                 )
    finally:
      # Release the data and the model even when training fails, so the next
      # fold of the sweep does not inherit the GPU memory of this one.
      del dls, learn
      clean_gpu_memory()

In [None]:
def get_pretrained(method: int = 0, path: str = 'ckpts/ResNet_mean_with_fast_LSPCQA_1_best.pth'):
  """Build ``myresnet50`` with the given fusion method and load weights from ``path``.

  Args:
    method: Value passed as ``feature_fusion_method``.
    path: Checkpoint file containing a state dict for the model.

  Returns:
    The model with the checkpoint weights loaded.
  """
  model_pretrained = myresnet50(pretrained=False,feature_fusion_method=method)
  # Deserialize onto the CPU so a checkpoint saved from a GPU also loads on a
  # machine without that device; load_state_dict copies the values into the
  # model's own parameters.
  state_dict = torch.load(path, map_location='cpu')
  model_pretrained.load_state_dict(state_dict)
  return model_pretrained

#model_fusion_0 = myresnet50(pretrained=True, feature_fusion_method=0)
def get_feature_fusion(method: int = 1):
  """Return ``myresnet50(pretrained=True)`` using the given ``feature_fusion_method``."""
  return myresnet50(pretrained=True, feature_fusion_method=method)

In [None]:
def get_basic_ourdata_config():
  """Return a fresh ``dotdict`` with the default settings for the OURDATA database."""
  return dotdict(
      database='OURDATA',
      conv_base_lr=0.00004,
      decay_ratio=0.9,
      decay_interval=10,
      train_batch_size=32,
      num_workers=2,
      epochs=30,
      split_num=11,
      crop_size=224,
      frame_index=5,
      video_length_read=4,
      pretrained_model_path='ckpts/ResNet_mean_with_fast_LSPCQA_1_best.pth',
      feature_fusion_method=0,
  )

In [None]:
def get_sweep_config(split_num):
  """Return a wandb grid-sweep config that enumerates the fold indices.

  Args:
    split_num: Number of folds; the ``kfold`` parameter takes the values
      ``0 .. split_num - 1``.

  Returns:
    dict with the keys ``method``, ``metric`` (minimize ``rmse``) and
    ``parameters``.
  """
  return {
      'method': 'grid',
      'metric': {
          'name': 'rmse',
          'goal': 'minimize',
          },
      'parameters': {
          'kfold': {
              'values': list(range(split_num)),
              },
          },
      }

# Replicando resultados sobre SJTU

# Start running the experiments

In [None]:
import functools
# Fold sweep on SJTU: myresnet50(pretrained=True) with feature fusion 0,
# one kfold_model run per fold index.
model_func = functools.partial(myresnet50, pretrained=True, feature_fusion_method=0)
body = {
    'dls_config': config,
    'model_func': model_func,
    'optimizer': Adam,
    'model_name': 'Scratch_SJTU',
    'dataset': 'SJTU',
    'n_epoch': 30,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
import functools
# Same sweep as for SJTU, but on the resized SJTU images.
model_func = functools.partial(myresnet50, pretrained=True, feature_fusion_method=0)
config['database'] = 'SJTU_resized'
body = {
    'dls_config': config,
    'model_func': model_func,
    'optimizer': Adam,
    'model_name': 'Scratch_SJTU_resized',
    'dataset': 'SJTU_resized',
    'n_epoch': 30,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

# Comparación entre modelos sobre OurData


In [None]:
%cd /content/VQA_PC/main

## Pretrained SJTU

In [None]:
# Fold sweep on OURDATA starting from the default get_pretrained() checkpoint.
config = get_basic_ourdata_config()
body = {
    'dls_config': config,
    'model_func': get_pretrained,
    'optimizer': Adam,
    'model_name': 'Pretrained_MVQAPC',
    'dataset': 'OURDATA',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED starting from the default get_pretrained() checkpoint.
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED'
body = {
    'dls_config': config,
    'model_func': get_pretrained,
    'optimizer': Adam,
    'model_name': 'Pretrained_MVQAPC_scaled',
    'dataset': 'OURDATA_SCALED',
    'group': 'ScaledLabels',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED_RESIZED starting from the default get_pretrained() checkpoint.
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED_RESIZED'
body = {
    'dls_config': config,
    'model_func': get_pretrained,
    'optimizer': Adam,
    'model_name': 'Pretrained_MVQAPC_scaled',
    'dataset': 'OURDATA_SCALED_RESIZED',
    'group': 'ResizedScaledLabels',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_RESIZED starting from the default get_pretrained() checkpoint.
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_RESIZED'
body = {
    'dls_config': config,
    'model_func': get_pretrained,
    'optimizer': Adam,
    'model_name': 'Pretrained_MVQAPC_resized',
    'dataset': 'OURDATA_RESIZED',
    'group': 'ResizedLabels',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

## Feature Fusion 0

In [None]:
# Fold sweep on OURDATA_RESIZED with get_feature_fusion(0).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 0),
    'optimizer': Adam,
    'model_name': 'MVQAPC_resized',
    'dataset': 'OURDATA_RESIZED',
    'group': 'FeatureFusion0',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

## Feature Fusion 1

In [None]:
# Fold sweep on OURDATA_RESIZED with get_feature_fusion(1).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 1),
    'optimizer': Adam,
    'model_name': 'FeatureFusion1_MVQAPC',
    'dataset': 'OURDATA_RESIZED',
    'group': 'FeatureFusion1',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED_RESIZED with get_feature_fusion(1).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 1),
    'optimizer': Adam,
    'model_name': 'FeatureFusion1_MVQAPC',
    'dataset': 'OURDATA_SCALED_RESIZED',
    'group': 'FeatureFusion1',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA with get_feature_fusion(1).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 1),
    'optimizer': Adam,
    'model_name': 'FeatureFusion1_MVQAPC',
    'dataset': 'OURDATA',
    'group': 'FeatureFusion1',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED with get_feature_fusion(1).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 1),
    'optimizer': Adam,
    'model_name': 'FeatureFusion1_MVQAPC',
    'dataset': 'OURDATA_SCALED',
    'group': 'FeatureFusion1',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

## Feature Fusion 2

In [None]:
# Fold sweep on OURDATA_RESIZED with get_feature_fusion(2).
#
# This cell previously began with a pasted copy of the plain "pretrained"
# OURDATA_RESIZED sweep whose last line was fused with the next statement
# (`...count=config.split_num)config.database = ...`), which is a syntax error.
# The duplicated sweep is dropped so the cell matches the other Feature Fusion
# cells: one configuration, one sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'
body = dict(
  dls_config=config,
  model_func=functools.partial(get_feature_fusion, 2),
  optimizer=Adam,
  model_name='FeatureFusion2_MVQAPC',
  dataset='OURDATA_RESIZED',
  group='FeatureFusion2',
  n_epoch=30,
  leslie=False
)
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id,
            polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED_RESIZED with get_feature_fusion(2).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 2),
    'optimizer': Adam,
    'model_name': 'FeatureFusion2_MVQAPC',
    'dataset': 'OURDATA_SCALED_RESIZED',
    'group': 'FeatureFusion2',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA with get_feature_fusion(2).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 2),
    'optimizer': Adam,
    'model_name': 'FeatureFusion2_MVQAPC',
    'dataset': 'OURDATA',
    'group': 'FeatureFusion2',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED with get_feature_fusion(2).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 2),
    'optimizer': Adam,
    'model_name': 'FeatureFusion2_MVQAPC',
    'dataset': 'OURDATA_SCALED',
    'group': 'FeatureFusion2',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

## Feature fusion 3

In [None]:
# Fold sweep on OURDATA_RESIZED with get_feature_fusion(3).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 3),
    'optimizer': Adam,
    'model_name': 'FeatureFusion3_MVQAPC',
    'dataset': 'OURDATA_RESIZED',
    'group': 'FeatureFusion3',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED_RESIZED with get_feature_fusion(3).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED_RESIZED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 3),
    'optimizer': Adam,
    'model_name': 'FeatureFusion3_MVQAPC',
    'dataset': 'OURDATA_SCALED_RESIZED',
    'group': 'FeatureFusion3',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA with get_feature_fusion(3).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 3),
    'optimizer': Adam,
    'model_name': 'FeatureFusion3_MVQAPC',
    'dataset': 'OURDATA',
    'group': 'FeatureFusion3',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fold sweep on OURDATA_SCALED with get_feature_fusion(3).
config = get_basic_ourdata_config()
config['database'] = 'OURDATA_SCALED'
body = {
    'dls_config': config,
    'model_func': functools.partial(get_feature_fusion, 3),
    'optimizer': Adam,
    'model_name': 'FeatureFusion3_MVQAPC',
    'dataset': 'OURDATA_SCALED',
    'group': 'FeatureFusion3',
    'n_epoch': 30,
    'leslie': False,
}
sweep_config = get_sweep_config(config.split_num)
polysweep_train = functools.partial(kfold_model, **body)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

# pretrain in LS-SJTU-PCQA

In [None]:
%cd /content/VQA_PC/extraction
!unzip -q -x ls_sjtu_features.zip
%cd /content/VQA_PC/rotation
!unzip -q -x LS-SJTU-PCQA.zip

In [None]:
%cd /content/VQA_PC/main

## No Leslie

In [None]:
# Set up pre-training on LS_SJTU_SCALED (split 0) with get_feature_fusion(0).
# Early stopping and checkpointing both track the Spearman metric; the
# checkpoint is written under the Learner path with the name 'feature0'.
config = get_basic_ourdata_config()
config['database'] = 'LS_SJTU_SCALED'
model = get_feature_fusion(0)
dls = get_dataloader(config, 0)
srocc = SpearmanCorrCoef()
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='feature0'),
]
metrics = [rmse, srocc]
learn = Learner(dls, model, metrics=metrics, loss_func=MSELossFlat(),
                path='/content/gdrive/MyDrive/')

In [None]:
# Train for up to 30 epochs at the configured base learning rate, using the
# `learn`, `config` and `cbs` objects currently defined in the notebook.
learn.fit(30, lr=config.conv_base_lr, cbs=cbs)

In [None]:
# Pre-train on LS_SJTU_SCALED (split 0) with get_feature_fusion(1).
# Early stopping and checkpointing both track the Spearman metric; the
# checkpoint is written under the Learner path with the name 'feature1'.
config = get_basic_ourdata_config()
config['database'] = 'LS_SJTU_SCALED'
model = get_feature_fusion(1)
dls = get_dataloader(config, 0)
srocc = SpearmanCorrCoef()
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='feature1'),
]
metrics = [rmse, srocc]
learn = Learner(dls, model, metrics=metrics, loss_func=MSELossFlat(),
                path='/content/gdrive/MyDrive/')
learn.fit(30, lr=config.conv_base_lr, cbs=cbs)

In [None]:
# Pre-train on LS_SJTU_SCALED (split 0) with get_feature_fusion(2).
# Early stopping and checkpointing both track the Spearman metric; the
# checkpoint is written under the Learner path with the name 'feature2'.
config = get_basic_ourdata_config()
config['database'] = 'LS_SJTU_SCALED'
model = get_feature_fusion(2)
dls = get_dataloader(config, 0)
srocc = SpearmanCorrCoef()
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='feature2'),
]
metrics = [rmse, srocc]
learn = Learner(dls, model, metrics=metrics, loss_func=MSELossFlat(),
                path='/content/gdrive/MyDrive/')
learn.fit(30, lr=config.conv_base_lr, cbs=cbs)

In [None]:
# Pre-train on LS_SJTU_SCALED (split 0) with get_feature_fusion(3).
# Early stopping and checkpointing both track the Spearman metric; the
# checkpoint is written under the Learner path with the name 'feature3'.
config = get_basic_ourdata_config()
config['database'] = 'LS_SJTU_SCALED'
model = get_feature_fusion(3)
dls = get_dataloader(config, 0)
srocc = SpearmanCorrCoef()
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='feature3'),
]
metrics = [rmse, srocc]
learn = Learner(dls, model, metrics=metrics, loss_func=MSELossFlat(),
                path='/content/gdrive/MyDrive/')
learn.fit(30, lr=config.conv_base_lr, cbs=cbs)

# Fine Tune From LS-SJTU-PCQA

In [None]:
import polars as pl
import glob
import os

# Rewrite every label CSV from our_data_info_norm into our_data_info_scaled,
# keeping the 'name' column and replacing 'mos' with (1 - mos) * 5.
scaled_dir = '/content/VQA_PC/main/database/our_data_info_scaled'
# write_csv does not create missing directories, so make sure the target exists.
os.makedirs(scaled_dir, exist_ok=True)

objs = glob.glob('/content/VQA_PC/main/database/our_data_info_norm/*.csv')
print(objs)
for obj in objs:
    df = pl.read_csv(obj)
    name = os.path.basename(obj)
    # Native expression instead of Expr.apply(lambda ...): `apply` was
    # deprecated in favour of `map_elements` and is absent from newer polars
    # releases. The explicit alias keeps the output column named 'mos'.
    df = df.select(
        pl.col('name'),
        ((1 - pl.col('mos')) * 5).alias('mos'),
    )
    df.write_csv(f'{scaled_dir}/{name}')

In [None]:
# Fine-tune the 'feature0' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=0,
        path='/content/gdrive/MyDrive/models/feature0.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F0-v3',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-F0-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature0' checkpoint on 'OURDATA_RESIZED_SCALED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED_SCALED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=0,
        path='/content/gdrive/MyDrive/models/feature0.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F0-v3',
    'dataset': 'OURDATA_RESIZED_SCALED',
    'group': 'LS-SJTU-F0-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature1' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=1,
        path='/content/gdrive/MyDrive/models/feature1.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F1-v3',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-F1-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature1' checkpoint on 'OURDATA_RESIZED_SCALED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED_SCALED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=1,
        path='/content/gdrive/MyDrive/models/feature1.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F1-v3',
    'dataset': 'OURDATA_RESIZED_SCALED',
    'group': 'LS-SJTU-F1-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature2' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=2,
        path='/content/gdrive/MyDrive/models/feature2.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F2-v3',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-F2-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature2' checkpoint on 'OURDATA_RESIZED_SCALED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED_SCALED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=2,
        path='/content/gdrive/MyDrive/models/feature2.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F2-v3',
    'dataset': 'OURDATA_RESIZED_SCALED',
    'group': 'LS-SJTU-F2-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'feature3' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=3,
        path='/content/gdrive/MyDrive/models/feature3.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-F3-v3',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-F3-v3',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

## Leslie (`lr_find` + `fit_one_cycle`) + LS-SJTU-PCQA

In [None]:
# Train the model returned by get_feature_fusion(0) on 'LS_SJTU_SCALED' using
# lr_find + fit_one_cycle, checkpointing under the name 'Lesliefeature0'.
config = get_basic_ourdata_config()
config.database = 'LS_SJTU_SCALED'

# Model is created before the dataloaders, as in the other training cells.
model = get_feature_fusion(0)
dls = get_dataloader(config, 0)

# Spearman correlation is both reported and used as the monitored quantity.
srocc = SpearmanCorrCoef()
metrics = [rmse, srocc]

# Stop after 6 epochs without improvement; keep the monitored-best weights.
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='Lesliefeature0'),
]

learn = Learner(
    dls,
    model,
    metrics=metrics,
    loss_func=MSELossFlat(),
    path='/content/gdrive/MyDrive/',
)

# The first value returned by lr_find is used as the one-cycle peak rate.
lr = learn.lr_find(show_plot=False)
learn.fit_one_cycle(30, lr_max=lr[0], cbs=cbs)

In [None]:
# Train the model returned by get_feature_fusion(1) on 'LS_SJTU_SCALED' using
# lr_find + fit_one_cycle, checkpointing under the name 'Lesliefeature1'.
config = get_basic_ourdata_config()
config.database = 'LS_SJTU_SCALED'

# Model is created before the dataloaders, as in the other training cells.
model = get_feature_fusion(1)
dls = get_dataloader(config, 0)

# Spearman correlation is both reported and used as the monitored quantity.
srocc = SpearmanCorrCoef()
metrics = [rmse, srocc]

# Stop after 6 epochs without improvement; keep the monitored-best weights.
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='Lesliefeature1'),
]

learn = Learner(
    dls,
    model,
    metrics=metrics,
    loss_func=MSELossFlat(),
    path='/content/gdrive/MyDrive/',
)

# The first value returned by lr_find is used as the one-cycle peak rate.
lr = learn.lr_find(show_plot=False)
learn.fit_one_cycle(30, lr_max=lr[0], cbs=cbs)

In [None]:
# Train the model returned by get_feature_fusion(2) on 'LS_SJTU_SCALED' using
# lr_find + fit_one_cycle, checkpointing under the name 'Lesliefeature2'.
config = get_basic_ourdata_config()
config.database = 'LS_SJTU_SCALED'

# Model is created before the dataloaders, as in the other training cells.
model = get_feature_fusion(2)
dls = get_dataloader(config, 0)

# Spearman correlation is both reported and used as the monitored quantity.
srocc = SpearmanCorrCoef()
metrics = [rmse, srocc]

# Stop after 6 epochs without improvement; keep the monitored-best weights.
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='Lesliefeature2'),
]

learn = Learner(
    dls,
    model,
    metrics=metrics,
    loss_func=MSELossFlat(),
    path='/content/gdrive/MyDrive/',
)

# The first value returned by lr_find is used as the one-cycle peak rate.
lr = learn.lr_find(show_plot=False)
learn.fit_one_cycle(30, lr_max=lr[0], cbs=cbs)

In [None]:
# Train the model returned by get_feature_fusion(3) on 'LS_SJTU_SCALED' using
# lr_find + fit_one_cycle, checkpointing under the name 'Lesliefeature3'.
config = get_basic_ourdata_config()
config.database = 'LS_SJTU_SCALED'

# Model is created before the dataloaders, as in the other training cells.
model = get_feature_fusion(3)
dls = get_dataloader(config, 0)

# Spearman correlation is both reported and used as the monitored quantity.
srocc = SpearmanCorrCoef()
metrics = [rmse, srocc]

# Stop after 6 epochs without improvement; keep the monitored-best weights.
cbs = [
    EarlyStoppingCallback(monitor='spearmanr', patience=6),
    SaveModelCallback(monitor='spearmanr', fname='Lesliefeature3'),
]

learn = Learner(
    dls,
    model,
    metrics=metrics,
    loss_func=MSELossFlat(),
    path='/content/gdrive/MyDrive/',
)

# The first value returned by lr_find is used as the one-cycle peak rate.
lr = learn.lr_find(show_plot=False)
learn.fit_one_cycle(30, lr_max=lr[0], cbs=cbs)

## Leslie Finetune OurData

In [None]:
import polars as pl
import glob
import os

# Rewrite every label CSV from our_data_info_norm into our_data_info_scaled,
# keeping the 'name' column and replacing 'mos' with (1 - mos) * 5.
scaled_dir = '/content/VQA_PC/main/database/our_data_info_scaled'
# write_csv does not create missing directories, so make sure the target exists.
os.makedirs(scaled_dir, exist_ok=True)

objs = glob.glob('/content/VQA_PC/main/database/our_data_info_norm/*.csv')
print(objs)
for obj in objs:
    df = pl.read_csv(obj)
    name = os.path.basename(obj)
    # Native expression instead of Expr.apply(lambda ...): `apply` was
    # deprecated in favour of `map_elements` and is absent from newer polars
    # releases. The explicit alias keeps the output column named 'mos'.
    df = df.select(
        pl.col('name'),
        ((1 - pl.col('mos')) * 5).alias('mos'),
    )
    df.write_csv(f'{scaled_dir}/{name}')

In [None]:
# Fine-tune the 'Lesliefeature0' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = dict(
    dls_config=config,
    model_func=functools.partial(
        get_pretrained,
        method=0,
        path='/content/gdrive/MyDrive/models/Lesliefeature0.pth',
    ),
    optimizer=Adam,
    model_name='LS-SJTU-LESLIE-F0-v2',
    # Derived from config.database so the two cannot disagree; every other
    # fine-tuning cell passes the same value for both.
    # NOTE(review): if the scaled-label variant was intended for this run, set
    # config.database = 'OURDATA_RESIZED_SCALED' above instead. Presumably
    # `dataset` is only a run label inside kfold_model -- confirm.
    dataset=config.database,
    group='LS-SJTU-LESLIE-F0-v2',
    n_epoch=30,
    leslie=False,
)
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'Lesliefeature1' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=1,
        path='/content/gdrive/MyDrive/models/Lesliefeature1.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-LESLIE-F1-v2',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-LESLIE-F1-v2',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'Lesliefeature1' checkpoint on 'OURDATA_RESIZED_SCALED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED_SCALED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=1,
        path='/content/gdrive/MyDrive/models/Lesliefeature1.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-LESLIE-F1-v2',
    'dataset': 'OURDATA_RESIZED_SCALED',
    'group': 'LS-SJTU-LESLIE-F1-v2',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'Lesliefeature2' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=2,
        path='/content/gdrive/MyDrive/models/Lesliefeature2.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-LESLIE-F2-v2',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-LESLIE-F2-v2',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

In [None]:
# Fine-tune the 'Lesliefeature3' checkpoint on 'OURDATA_RESIZED' through a wandb sweep.
config = get_basic_ourdata_config()
config.database = 'OURDATA_RESIZED'

# Keyword arguments bound onto kfold_model for every sweep run.
body = {
    'dls_config': config,
    'model_func': functools.partial(
        get_pretrained,
        method=3,
        path='/content/gdrive/MyDrive/models/Lesliefeature3.pth',
    ),
    'optimizer': Adam,
    'model_name': 'LS-SJTU-LESLIE-F3-v2',
    'dataset': 'OURDATA_RESIZED',
    'group': 'LS-SJTU-LESLIE-F3-v2',
    'n_epoch': 30,
    'leslie': False,
}
polysweep_train = functools.partial(kfold_model, **body)

# Register the sweep and run the agent with count=config.split_num.
sweep_config = get_sweep_config(config.split_num)
sweep_id = wandb.sweep(sweep=sweep_config, project='MVQAPC')
wandb.agent(sweep_id, polysweep_train, count=config.split_num)

# Before Quitting

In [None]:
# Run before ending the session: flush and unmount the mounted Google Drive
# (the checkpoints above are written under /content/gdrive/MyDrive/).
from google.colab import drive
drive.flush_and_unmount()