In [None]:
class WRN(nn.Module):
    """
    Flexibly sized Wide Residual Network (WRN). Extended to the variational setting.
    """
    def __init__(self, device, num_classes, num_colors, args):
        super(WRN, self).__init__()
        
        self.encoder = make_layers([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],batch_norm=True)
        
        # self.encoder = nn.Sequential(OrderedDict([
        #     ('encoder_conv1', nn.Conv2d(num_colors, self.nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)),
        #     ('encoder_block1', WRNNetworkBlock(self.num_block_layers, self.nChannels[0], self.nChannels[1],
        #                                        WRNBasicBlock, batchnorm=self.batch_norm, dropout=self.dropout)),
        #     ('encoder_block2', WRNNetworkBlock(self.num_block_layers, self.nChannels[1], self.nChannels[2],
        #                                        WRNBasicBlock, batchnorm=self.batch_norm, stride=2,
        #                                        dropout=self.dropout)),
        #     ('encoder_block3', WRNNetworkBlock(self.num_block_layers, self.nChannels[2], self.nChannels[3],
        #                                        WRNBasicBlock, batchnorm=self.batch_norm, stride=2,
        #                                        dropout=self.dropout)),
        #     ('encoder_bn1', nn.BatchNorm2d(self.nChannels[3], eps=self.batch_norm)),
        #     ('encoder_act1', nn.ReLU(inplace=True))
        # ]))

        self.enc_channels, self.enc_spatial_dim_x, self.enc_spatial_dim_y = get_feat_size(self.encoder, self.patch_size,self.num_colors)
        if self.variational:
            #print("True it's in the args variational")
            self.latent_mu = nn.Linear(self.enc_spatial_dim_x * self.enc_spatial_dim_x * self.enc_channels,self.latent_dim, bias=False)
            self.latent_std = nn.Linear(self.enc_spatial_dim_x * self.enc_spatial_dim_y * self.enc_channels,self.latent_dim, bias=False)
            self.latent_feat_out = self.latent_dim
        else:
            self.latent_feat_out = self.enc_spatial_dim_x * self.enc_spatial_dim_x * self.enc_channels
            self.latent_dim = self.latent_feat_out
            #print(self.latent_dim)

        if self.joint:
            #print("True came to joint training")
            self.classifier = nn.Sequential(nn.Linear(self.latent_feat_out, num_classes, bias=False))

            if self.variational:
                #print("True came to variational",self.latent_feat_out,self.enc_spatial_dim_x * self.enc_spatial_dim_y *
                                               # self.enc_channels)
                self.latent_decoder = nn.Linear(self.latent_feat_out, self.enc_spatial_dim_x * self.enc_spatial_dim_y *
                                                self.enc_channels, bias=False)

            self.decoder =  nn.Sequential(
            # nn.Linear(1,self.enc_spatial_dim_x * self.enc_spatial_dim_y *
            #                                     self.enc_channels),                           # B, 1024*8*8
            # View((-1, 1024, 4, 4)),                               # B, 1024,  8,  8
            nn.ConvTranspose2d(512, 512, 3, 1, 1, bias=False),   # B,  512, 16, 16
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 3, 1, 1, bias=False),    # B,  256, 32, 32
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 3, 1, 1, bias=False),    # B,  128, 64, 64
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 1),                       # B,   nc, 64, 64
        )
            
        
        else:
            self.classifier = nn.Sequential(nn.Linear(self.latent_feat_out, num_classes, bias=False))

        # self._initialize_weights()

    def encode(self, x):
        
        x = self.encoder(x)
        if self.variational:
            x = x.view(x.size(0), -1)
            z_mean = self.latent_mu(x)
            z_std = self.latent_std(x)
            return z_mean, z_std
        else:
            return x

    def reparameterize(self, mu, std):
        eps = std.data.new(std.size()).normal_()
        return eps.mul(std).add(mu)

    def decode(self, z):
        if self.variational:
            z = self.latent_decoder(z)
            z = z.view(z.size(0), self.enc_channels, self.enc_spatial_dim_x, self.enc_spatial_dim_y)
            #print("shape of z is",z.shape)
        x = self.decoder(z)
        return x

    def generate(self):
        z = torch.randn(self.batch_size, self.latent_dim).to(self.device)
        x = self.decode(z)
        x = torch.sigmoid(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



    def forward(self, x):
        if self.variational:
            z_mean, z_std = self.encode(x)
            if self.joint:
                output_samples = torch.zeros(self.num_samples, x.size(0), self.num_colors, self.patch_size,
                                             self.patch_size).to(self.device)
                classification_samples = torch.zeros(self.num_samples, x.size(0), self.num_classes).to(self.device)
            else:
                output_samples = torch.zeros(self.num_samples, x.size(0), self.num_classes).to(self.device)
            for i in range(self.num_samples):
                z = self.reparameterize(z_mean, z_std)
                if self.joint:
                    output_samples[i] = self.decode(z)
                    classification_samples[i] = self.classifier(z)
                else:
                    output_samples[i] = self.classifier(z)
            if self.joint:
                return classification_samples, output_samples, z_mean, z_std
            else:
                return output_samples, z_mean, z_std
        else:
            x = self.encode(x)
            if self.joint:
                recon = self.decode(x)
                classification = self.classifier(x.view(x.size(0), -1))
                return classification, recon
            else:
                output = self.classifier(x.view(x.size(0), -1))
            return output

In [1]:
import pickle

In [8]:
with open ('preds', 'rb') as fp:preds = pickle.load(fp)

In [15]:
preds[0].shape#.view(128,10)

torch.Size([128, 10])

In [21]:
import torch
preds=torch.argmax(preds[0], dim=1).cpu().numpy()

In [18]:
with open ('labels', 'rb') as fp:labels = pickle.load(fp)

In [19]:
labels

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])

In [22]:
from sklearn.metrics import accuracy_score

accuracy_score(labels, preds, normalize=False)

13

In [1]:
from torchvision import datasets, transforms
import torch.utils.data as data

In [48]:
test_dataloader =  data.DataLoader(
                datasets.Caltech256('/mnt/iscsi/data/Jay/datasets/', download=True),
            batch_size=128, drop_last=False)
        

Files already downloaded and verified


In [49]:
len(test_dataloader)

240

In [51]:
240*128

30720

In [1]:
from custom_datasets import *
def imagenet_transformer():
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                              std=[0.229, 0.224, 0.225])
    ])

In [86]:
# test_dataloader =  data.DataLoader(datasets.Caltech256('/mnt/iscsi/data/Jay/datasets/caltech256/', download=True,transform=imagenet_transformer()),
#             batch_size=128, drop_last=False)
from torchvision import datasets, transforms
import torch.utils.data as data
class Caltech2556(Dataset):
    def __init__(self, path):
        self.caltech256 =  datasets.ImageFolder(root=path,transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]))

    def __getitem__(self, index):
        if isinstance(index, numpy.float64):
            index = index.astype(numpy.int64)
        data, target = self.caltech256[index]

        return data, target, index

    def __len__(self):
        return len(self.caltech256)
    
train_dataset = Caltech2556('/mnt/iscsi/data/Jay/datasets/caltech256/caltech256/train/')

In [87]:
train_dataset.caltech256.classes

['001.ak47',
 '002.american-flag',
 '003.backpack',
 '004.baseball-bat',
 '005.baseball-glove',
 '006.basketball-hoop',
 '007.bat',
 '008.bathtub',
 '009.bear',
 '010.beer-mug',
 '011.billiards',
 '012.binoculars',
 '013.birdbath',
 '014.blimp',
 '015.bonsai-101',
 '016.boom-box',
 '017.bowling-ball',
 '018.bowling-pin',
 '019.boxing-glove',
 '020.brain-101',
 '021.breadmaker',
 '022.buddha-101',
 '023.bulldozer',
 '024.butterfly',
 '025.cactus',
 '026.cake',
 '027.calculator',
 '028.camel',
 '029.cannon',
 '030.canoe',
 '031.car-tire',
 '032.cartman',
 '033.cd',
 '034.centipede',
 '035.cereal-box',
 '036.chandelier-101',
 '037.chess-board',
 '038.chimp',
 '039.chopsticks',
 '040.cockroach',
 '041.coffee-mug',
 '042.coffin',
 '043.coin',
 '044.comet',
 '045.computer-keyboard',
 '046.computer-monitor',
 '047.computer-mouse',
 '048.conch',
 '049.cormorant',
 '050.covered-wagon',
 '051.cowboy-hat',
 '052.crab-101',
 '053.desk-globe',
 '054.diamond-ring',
 '055.dice',
 '056.dog',
 '057.dol

In [88]:
len(train_dataset)

29780

In [91]:
for i, (images, target,_) in enumerate(train_dataset):
    #if i%10==0:print(i,images.shape,target,_)
    print(i,images.shape,target,_)
    sys.exit()

0 torch.Size([3, 224, 224]) 0 0


NameError: name 'sys' is not defined

In [None]:
num_val = 128120
num_images = 30607
budget = 64060
initial_budget = 3000
num_classes = 257

In [64]:
data_path='/mnt/iscsi/data/Jay/datasets/cifar10'

In [85]:
test_dataloader = data.DataLoader(
                datasets.CIFAR10(data_path, download=True, train=False),
            batch_size=128, drop_last=False)
train_dataset = CIFAR10(data_path)

NameError: name 'data_path' is not defined

In [68]:
len(train_dataset)

50000

In [28]:
from easydict import EasyDict
args = EasyDict()
args

{}

In [37]:
args.num_val = 2700
args.num_images = 27607
args.budget = 1530
args.initial_budget = 3060
args.num_classes = 256
args.test=3000

In [38]:
args

{'budget': 1530,
 'initial_budget': 3060,
 'num_classes': 256,
 'num_images': 27607,
 'num_val': 2700,
 'test': 3000}

In [39]:
import numpy as np
import random
all_indices = set(np.arange(args.num_images))
val_indices = random.sample(all_indices, args.num_val)
all_indices = np.setdiff1d(list(all_indices), val_indices)
initial_indices = random.sample(list(all_indices), args.initial_budget)
sampler = data.sampler.SubsetRandomSampler(initial_indices)
val_sampler = data.sampler.SubsetRandomSampler(val_indices)

In [40]:
len(val_indices)

2700

In [67]:
import torch
a=torch.Tensor(2,64,256)

In [68]:
a

tensor([[[3.2678e-41, 0.0000e+00, 3.2680e-41,  ..., 0.0000e+00,
          3.2856e-41, 0.0000e+00],
         [3.2858e-41, 0.0000e+00, 3.2859e-41,  ..., 0.0000e+00,
          3.3036e-41, 0.0000e+00],
         [3.3037e-41, 0.0000e+00, 3.3038e-41,  ..., 0.0000e+00,
          3.3215e-41, 0.0000e+00],
         ...,
         [4.9326e-42, 0.0000e+00, 4.9340e-42,  ..., 0.0000e+00,
          5.1105e-42, 0.0000e+00],
         [5.1119e-42, 0.0000e+00, 5.1133e-42,  ..., 0.0000e+00,
          5.2899e-42, 0.0000e+00],
         [5.2913e-42, 0.0000e+00, 5.2927e-42,  ..., 0.0000e+00,
          5.4693e-42, 0.0000e+00]],

        [[5.4707e-42, 0.0000e+00, 5.4721e-42,  ..., 0.0000e+00,
          5.6486e-42, 0.0000e+00],
         [5.6500e-42, 0.0000e+00, 5.6514e-42,  ..., 0.0000e+00,
          5.8280e-42, 0.0000e+00],
         [5.8294e-42, 0.0000e+00, 5.8308e-42,  ..., 0.0000e+00,
          6.0074e-42, 0.0000e+00],
         ...,
         [1.6412e-41, 0.0000e+00, 1.6413e-41,  ..., 0.0000e+00,
          1.659

In [72]:
z

tensor([[[3.2678e-41, 0.0000e+00, 3.2680e-41,  ..., 0.0000e+00,
          3.2856e-41, 0.0000e+00],
         [3.2858e-41, 0.0000e+00, 3.2859e-41,  ..., 0.0000e+00,
          3.3036e-41, 0.0000e+00],
         [3.3037e-41, 0.0000e+00, 3.3038e-41,  ..., 0.0000e+00,
          3.3215e-41, 0.0000e+00],
         ...,
         [1.6412e-41, 0.0000e+00, 1.6413e-41,  ..., 0.0000e+00,
          1.6590e-41, 0.0000e+00],
         [1.6591e-41, 0.0000e+00, 1.6593e-41,  ..., 0.0000e+00,
          1.6769e-41, 0.0000e+00],
         [1.6771e-41, 0.0000e+00, 1.6772e-41,  ..., 0.0000e+00,
          1.6949e-41, 0.0000e+00]]])

In [62]:
c=torch.Tensor(1,128,256)

In [63]:
b=torch.cat((a[0], a[1]))

In [71]:
z=a.view_as(c)

In [65]:
b.shape

torch.Size([128, 256])

In [74]:
import pickle
with open ('output_samples_', 'rb') as fp:output_samples_ = pickle.load(fp)

In [84]:
len(output_samples_[0][0])

256

In [78]:
with open ('target', 'rb') as fp:target = pickle.load(fp)

In [79]:
target

tensor([ 82, 256, 229,  47, 190,  68, 239,  54,  18,  12, 112, 176,   1, 125,
          4, 199, 141, 202, 224, 256, 250,   3, 171, 223,  50, 147,  80, 180,
        190,  43, 109,   9, 252, 104,  28,  69,  73, 207,  57,  69, 202, 116,
         57, 256, 229,  10, 211,  41,  44, 215,  91, 106,  43, 152, 223, 236,
        189, 240, 148,  93, 118,  21, 232, 241, 242, 231, 163, 142, 252,  24,
         89, 139, 144, 157,  97, 185, 142, 159, 209, 156, 190, 124,  91, 113,
        104, 239, 153,  63,  86,  97,  37,  33,  33,   1,  60, 151,  94, 254,
         29, 245,  82, 195, 256,  52,  38,  76, 253, 171, 144, 117, 213, 109,
        126, 213, 162, 155, 215, 178, 104,  86, 243,  27,  22,  76, 123, 180,
         49, 243], device='cuda:0')