In [None]:
class WRN(nn.Module):
    """
    Encoder + linear classifier, optionally variational and optionally with a
    transposed-convolution decoder ("joint" mode).

    Despite the class name, the encoder is not assembled from wide residual
    blocks: it is whatever ``make_layers`` returns for the VGG-16-style layer
    configuration passed in ``__init__``.

    Modes:
      * ``variational`` -- the flattened encoder features are mapped by two
        bias-free linear layers to ``z_mean`` and ``z_std`` and ``num_samples``
        latent samples are drawn in every forward pass.
      * ``joint`` -- a decoder is built in addition to the classifier, so
        ``decode``/``generate`` are only usable in this mode.

    Args:
        device: device on which sample buffers and generated latents are placed.
        num_classes: output size of the linear classifier.
        num_colors: number of input image channels (also the decoder's output
            channel count).
        args: namespace-like configuration object.
            NOTE(review): it is assumed to expose ``patch_size``, ``batch_size``,
            ``train_var``, ``joint`` and -- when ``train_var`` is set --
            ``var_samples`` and ``var_latent_dim``. These names are not visible
            in this file; confirm them against the project's argument parser.
    """
    def __init__(self, device, num_classes, num_colors, args):
        super(WRN, self).__init__()

        # The rest of __init__ and every other method read these from ``self``;
        # they were previously never assigned, so construction failed.
        self.device = device
        self.num_classes = num_classes
        self.num_colors = num_colors
        # NOTE(review): attribute names on ``args`` are assumed -- confirm.
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.variational = args.train_var
        self.joint = args.joint

        self.encoder = make_layers([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],batch_norm=True)

        self.enc_channels, self.enc_spatial_dim_x, self.enc_spatial_dim_y = get_feat_size(
            self.encoder, self.patch_size, self.num_colors)
        # Size of the flattened encoder output. Computed once as x * y * channels;
        # two call sites used x * x, which is wrong for non-square feature maps.
        enc_feat_size = self.enc_spatial_dim_x * self.enc_spatial_dim_y * self.enc_channels

        if self.variational:
            # NOTE(review): attribute names on ``args`` are assumed -- confirm.
            self.num_samples = args.var_samples
            self.latent_dim = args.var_latent_dim
            self.latent_mu = nn.Linear(enc_feat_size, self.latent_dim, bias=False)
            self.latent_std = nn.Linear(enc_feat_size, self.latent_dim, bias=False)
            self.latent_feat_out = self.latent_dim
        else:
            self.latent_feat_out = enc_feat_size
            self.latent_dim = self.latent_feat_out

        # The classifier is identical in joint and non-joint mode.
        self.classifier = nn.Sequential(nn.Linear(self.latent_feat_out, num_classes, bias=False))

        if self.joint:
            if self.variational:
                # Maps a latent sample back to the flattened encoder feature size.
                self.latent_decoder = nn.Linear(self.latent_feat_out, enc_feat_size, bias=False)

            # Kernel 3 / stride 1 / padding 1 transposed convolutions keep the
            # spatial size of their input; only the channel count changes
            # (512 -> 512 -> 256 -> 128 -> num_colors). The first layer expects
            # 512 input channels, matching the last entry of the encoder config.
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(512, 512, 3, 1, 1, bias=False),
                nn.BatchNorm2d(512),
                nn.ReLU(True),
                nn.ConvTranspose2d(512, 256, 3, 1, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                nn.ConvTranspose2d(256, 128, 3, 1, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                # Output channels follow num_colors (was hard-coded to 3) so the
                # result matches the buffer allocated in forward().
                nn.ConvTranspose2d(128, num_colors, 1),
            )

        # _initialize_weights() is not called here, so PyTorch's default
        # initialisation is used.

    def encode(self, x):
        """
        Run the encoder.

        Returns:
            ``(z_mean, z_std)`` computed from the flattened features in
            variational mode, otherwise the unflattened encoder output.
        """
        x = self.encoder(x)
        if self.variational:
            x = x.view(x.size(0), -1)
            z_mean = self.latent_mu(x)
            # NOTE(review): this is the raw output of a linear layer, so it is
            # unconstrained in sign, yet it is used directly as the scale in
            # reparameterize() -- confirm the loss interprets it the same way.
            z_std = self.latent_std(x)
            return z_mean, z_std
        else:
            return x

    def reparameterize(self, mu, std):
        """Draw ``mu + eps * std`` with ``eps`` standard normal, same shape/device as ``std``."""
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu)

    def decode(self, z):
        """
        Decode ``z`` with the transposed-convolution decoder.

        In variational mode ``z`` is first projected by ``latent_decoder`` and
        reshaped to ``(batch, enc_channels, enc_spatial_dim_x, enc_spatial_dim_y)``.
        The decoder modules only exist when the model was built with ``joint``.
        """
        if self.variational:
            z = self.latent_decoder(z)
            z = z.view(z.size(0), self.enc_channels, self.enc_spatial_dim_x, self.enc_spatial_dim_y)
        x = self.decoder(z)
        return x

    def generate(self):
        """Decode ``batch_size`` standard-normal latent vectors and squash with a sigmoid."""
        z = torch.randn(self.batch_size, self.latent_dim).to(self.device)
        x = self.decode(z)
        x = torch.sigmoid(x)
        return x

    def _initialize_weights(self):
        """
        Optional custom initialisation: Kaiming-normal for ``nn.Conv2d`` weights,
        ones/zeros for ``nn.BatchNorm2d``, N(0, 0.01) for ``nn.Linear`` weights,
        and zeros for every bias that exists.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the layer was built with affine=False
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Every Linear layer in this model uses bias=False, so the bias
                # must be checked before it is touched.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Forward pass.

        Returns:
            * variational + joint: ``(classification_samples, output_samples, z_mean, z_std)``
            * variational only:    ``(output_samples, z_mean, z_std)`` where
              ``output_samples`` holds the classifier outputs
            * joint only:          ``(classification, recon)``
            * neither:             ``classification``

            In variational mode the sample tensors have a leading dimension of
            ``num_samples``.
        """
        if self.variational:
            z_mean, z_std = self.encode(x)
            batch = x.size(0)
            if self.joint:
                # NOTE(review): the decoder does not upsample, so the assignment
                # into this (patch_size x patch_size) buffer below relies on the
                # decoder output being broadcastable to that shape -- confirm.
                output_samples = torch.zeros(self.num_samples, batch, self.num_colors, self.patch_size,
                                             self.patch_size, device=self.device)
                classification_samples = torch.zeros(self.num_samples, batch, self.num_classes,
                                                     device=self.device)
            else:
                output_samples = torch.zeros(self.num_samples, batch, self.num_classes,
                                             device=self.device)
            for i in range(self.num_samples):
                z = self.reparameterize(z_mean, z_std)
                if self.joint:
                    output_samples[i] = self.decode(z)
                    classification_samples[i] = self.classifier(z)
                else:
                    output_samples[i] = self.classifier(z)
            if self.joint:
                return classification_samples, output_samples, z_mean, z_std
            else:
                return output_samples, z_mean, z_std
        else:
            x = self.encode(x)
            if self.joint:
                recon = self.decode(x)
                classification = self.classifier(x.view(x.size(0), -1))
                return classification, recon
            else:
                output = self.classifier(x.view(x.size(0), -1))
            return output

In [1]:
import pickle

In [8]:
# Read the pickled predictions from the file 'preds' in the working directory.
# pickle can execute arbitrary code while loading -- only open files you produced yourself.
with open('preds', 'rb') as fp:
    preds = pickle.load(fp)

In [15]:
# Shape of the first entry of `preds` (recorded output: torch.Size([128, 10])).
preds[0].shape

torch.Size([128, 10])

In [21]:
import torch

# Reduce the scores in preds[0] to one predicted class index per row (argmax over dim 1)
# and convert the result to a NumPy array on the CPU.
# NOTE: this rebinds `preds`, so the cell cannot be re-run without reloading the pickle first.
preds = preds[0].argmax(dim=1).cpu().numpy()

In [18]:
# Read the pickled ground-truth labels from the file 'labels' in the working directory.
# pickle can execute arbitrary code while loading -- only open files you produced yourself.
with open('labels', 'rb') as fp:
    labels = pickle.load(fp)

In [19]:
# Display the loaded labels (the recorded output is a tensor of 128 class indices).
labels

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])

In [22]:
from sklearn.metrics import accuracy_score

# normalize=False makes accuracy_score return the number of matching predictions
# instead of the fraction. The recorded run printed 13, i.e. 13 of the 128 samples
# (about 10%), which is roughly chance level for 10 classes.
accuracy_score(y_true=labels, y_pred=preds, normalize=False)

13

In [24]:
from torchvision import datasets, transforms
import torch.utils.data as data

In [48]:
# Wrap torchvision's Caltech256 in a DataLoader with batches of 128, keeping the
# final partial batch. Despite the variable name, no train/test split is selected here.
# NOTE(review): no transform is passed, so samples are presumably PIL images and
# iterating this loader with the default collate function would fail; the cells
# below only call len() on it -- confirm before iterating.
test_dataloader = data.DataLoader(
    dataset=datasets.Caltech256(root='/mnt/iscsi/data/Jay/datasets/', download=True),
    batch_size=128,
    drop_last=False,
)
        

Files already downloaded and verified


In [49]:
# Number of batches the loader yields (recorded output: 240).
len(test_dataloader)

240

In [51]:
# 240 batches x batch size 128: an upper bound on the sample count, since the
# last batch may be partial (the loader was built with drop_last=False).
240*128

30720

In [69]:
from custom_datasets import *


In [83]:
# Rebuild the Caltech256 loader with a different root directory, and create the
# project's own Caltech256 dataset (the unqualified name, not torchvision's) from
# the same path.
# NOTE(review): the torchvision dataset is given no transform, so iterating the
# loader with the default collate function would presumably fail -- confirm.
test_dataloader = data.DataLoader(
    dataset=datasets.Caltech256(root='/mnt/iscsi/data/Jay/datasets/caltech256/', download=True),
    batch_size=128,
    drop_last=False,
)
train_dataset = Caltech256('/mnt/iscsi/data/Jay/datasets/caltech256/')

Files already downloaded and verified


In [85]:
# Number of samples in the custom Caltech256 dataset (recorded output: 30607).
len(train_dataset)

30607

In [None]:
# Experiment constants (this cell has no recorded execution count).
# num_images matches the len(train_dataset) recorded above for Caltech256 (30607).
# NOTE(review): num_val (128120) and budget (64060) are both larger than
# num_images; if they are meant as sample counts for this dataset they look
# carried over from a larger one -- confirm before use.
num_val = 128120
num_images = 30607
budget = 64060
initial_budget = 3000
num_classes = 257

In [64]:
# Root directory handed to both CIFAR10 dataset constructors in the next cell.
# Absolute, machine-specific path: adjust when running elsewhere.
data_path='/mnt/iscsi/data/Jay/datasets/cifar10'

In [66]:
# Loader over torchvision's CIFAR10 test split (train=False) in batches of 128,
# keeping the final partial batch, plus the project's own CIFAR10 dataset (the
# unqualified name) built from the same directory.
# NOTE(review): the torchvision dataset is given no transform, so iterating the
# loader with the default collate function would presumably fail -- confirm.
test_dataloader = data.DataLoader(
    dataset=datasets.CIFAR10(root=data_path, train=False, download=True),
    batch_size=128,
    drop_last=False,
)
train_dataset = CIFAR10(data_path)

Files already downloaded and verified
Files already downloaded and verified


In [68]:
# Number of samples in the custom CIFAR10 dataset (recorded output: 50000).
len(train_dataset)

50000