In [None]:
!pip install datasets

In [4]:
import PIL
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
from scipy.spatial import distance
import numpy as np

import cv2
from google.colab.patches import cv2_imshow
import matplotlib.image as pltimg
from sklearn.preprocessing import RobustScaler, StandardScaler, MinMaxScaler

import torch
from torch import optim, nn
from torchvision import models, transforms

In [58]:
from datasets import load_dataset
dataset = load_dataset('cifar100', split='train')

Reusing dataset cifar100 (/root/.cache/huggingface/datasets/cifar100/cifar100/1.0.0/0f9be8dd0480d385177a5c250878f4480651bbf0fc86d714b33d56c9aaad5160)


In [59]:
len(dataset)

50000

In [60]:
class FeatureExtractor(nn.Module):
  def __init__(self, model):
    super(FeatureExtractor, self).__init__()
		# Extract VGG-16 Feature Layers
    self.features = list(model.features)
    self.features = nn.Sequential(*self.features)
		# Extract VGG-16 Average Pooling Layer
    self.pooling = model.avgpool
		# Convert the image into one-dimensional vector
    self.flatten = nn.Flatten()
		# Extract the first part of fully-connected layer from VGG16
    self.fc = nn.Linear(in_features=25088, out_features=100, bias=True)
  
  def forward(self, x):
		# It will take the input 'x' until it returns the feature vector called 'out'
    out = self.features(x)
    out = self.pooling(out)
    out = self.flatten(out)
    out = self.fc(out) 
    return out 

# Initialize the model
model = models.vgg16(pretrained=True)
new_model = FeatureExtractor(model)

# Change the device to GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
new_model = new_model.to(device)

In [61]:
from tqdm import tqdm
import numpy as np
import cv2

# Transform the image, so it becomes readable with the model
transform = transforms.Compose([
  transforms.ToPILImage(),
  transforms.CenterCrop(512),
  transforms.Resize(448),
  transforms.ToTensor()                              
])

# Will contain the feature
features = []

# Iterate each image
for imag in tqdm(dataset['img']):
  # Set the image path
  # path = os.path.join('data', 'test', str(i) + '.jpg')
  # Read the file
  try:
    img = np.array(Image.fromarray(np.array(imag), 'RGB'))
    # Transform the image
    img = transform(img)
    # Reshape the image. PyTorch model reads 4-dimensional tensor
    # [batch_size, channels, width, height]
    img = img.reshape(1, 3, 448, 448)
    img = img.to(device)
    # We only extract features, so we don't need gradient
    with torch.no_grad():
      # Extract the feature from the image
      feature = new_model(img)
    # Convert to NumPy Array, Reshape it, and save it to features variable
    features.append(feature.cpu().detach().numpy().reshape(-1))
  except:
    print("OOPS.")
# Convert to NumPy Array
features = np.array(features)

100%|██████████| 50000/50000 [16:07<00:00, 51.68it/s]


In [63]:
train_feat_df = pd.DataFrame(features)
train_feat_df['class'] = dataset['fine_label']

In [64]:
train_feat_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,class
0,-0.045371,-0.163352,-0.051378,0.004362,0.027893,0.042137,-0.001304,-0.014094,-0.124964,-0.080679,-0.001682,-0.039629,0.035048,0.137121,-0.010952,0.008534,0.131559,-0.112238,0.013626,-0.016047,-0.076367,0.048996,-0.013941,-0.005744,-0.020101,-0.064818,-0.051215,0.016299,-0.070055,0.082759,0.052017,0.011038,-0.000432,0.103141,-0.112850,0.001230,-0.049395,0.017621,-0.020865,0.063921,...,-0.033187,-0.036532,-0.006328,-0.067447,-0.068509,0.079203,-0.000958,0.076199,-0.019641,0.145730,-0.001793,-0.078854,0.049385,0.101077,-0.054477,-0.038497,0.113558,0.051526,0.049739,0.054262,-0.013196,0.086547,0.032984,-0.036034,0.029471,-0.011078,0.043944,0.016925,-0.054256,-0.029285,0.035039,0.031290,0.055443,0.035477,0.007479,-0.011093,-0.002466,-0.004473,0.099756,19
1,-0.048349,-0.159363,-0.052219,0.007523,0.025379,0.032219,-0.003062,-0.023806,-0.126761,-0.085429,0.003677,-0.028680,0.034648,0.145837,-0.016032,0.006478,0.137036,-0.110687,0.023536,-0.008214,-0.074442,0.052551,-0.018330,-0.006017,-0.016035,-0.068557,-0.042975,0.024104,-0.073598,0.089826,0.051186,0.014123,0.002373,0.102102,-0.115084,-0.002718,-0.043838,0.021216,-0.024443,0.057783,...,-0.035865,-0.038850,-0.013313,-0.073054,-0.067324,0.080841,0.005542,0.081909,-0.026209,0.147555,0.000746,-0.088493,0.047297,0.102539,-0.049697,-0.032526,0.113181,0.058695,0.055961,0.048006,-0.019411,0.090791,0.028302,-0.034889,0.031250,-0.010302,0.048454,0.028517,-0.065781,-0.027154,0.043333,0.025249,0.065864,0.047645,0.014432,-0.020838,-0.003025,0.001974,0.106071,29
2,-0.045600,-0.157905,-0.050758,0.011302,0.021276,0.034158,-0.004956,-0.023113,-0.129093,-0.083826,0.002504,-0.027168,0.036097,0.144366,-0.016244,0.007141,0.136660,-0.113312,0.027702,-0.008820,-0.073748,0.050143,-0.019548,-0.005121,-0.010675,-0.067167,-0.042906,0.026487,-0.070859,0.086989,0.050669,0.010482,0.003611,0.098422,-0.114743,-0.000294,-0.041523,0.021876,-0.025640,0.055786,...,-0.035144,-0.037522,-0.013612,-0.071892,-0.063444,0.078394,0.007223,0.081667,-0.024427,0.144901,0.001315,-0.088839,0.046962,0.102085,-0.046889,-0.034753,0.110942,0.060431,0.054377,0.048173,-0.018093,0.089450,0.026224,-0.035778,0.034848,-0.010152,0.048614,0.029363,-0.065129,-0.027129,0.042591,0.024180,0.066343,0.048486,0.015461,-0.019211,-0.004312,-0.001787,0.106549,0
3,-0.038647,-0.154763,-0.039952,0.001397,0.031636,0.039921,-0.001247,-0.024302,-0.128230,-0.086665,0.015649,-0.024580,0.037803,0.123708,-0.017627,0.007359,0.137713,-0.093353,0.021279,-0.021718,-0.074458,0.026678,-0.005879,-0.013708,0.000881,-0.073184,-0.027376,0.042387,-0.066450,0.064248,0.042997,-0.002864,-0.001121,0.076232,-0.118771,-0.033312,-0.006186,0.010576,-0.029132,0.040193,...,-0.038533,-0.028387,0.003016,-0.067833,-0.053557,0.059573,0.013956,0.063686,-0.054074,0.112275,0.018346,-0.101084,0.043854,0.104195,-0.058157,-0.032039,0.092432,0.064139,0.040732,0.053960,-0.018155,0.108370,0.022140,-0.033278,0.029228,-0.008480,0.031524,0.036247,-0.060980,-0.024532,0.047524,0.014547,0.059441,0.070182,0.019596,-0.013347,-0.009868,-0.024743,0.100561,11
4,-0.021916,-0.154215,-0.056060,-0.014020,0.062271,0.037410,0.019408,-0.011449,-0.115691,-0.069562,-0.001713,-0.046327,0.022796,0.127758,-0.009192,-0.003460,0.121133,-0.093321,0.003214,-0.028389,-0.062069,0.018834,0.004040,-0.017026,-0.012003,-0.064693,-0.014599,0.039642,-0.038787,0.068781,0.038381,-0.000952,-0.020432,0.103773,-0.104720,-0.032203,0.006418,0.023234,-0.038862,0.038389,...,-0.015888,-0.013812,-0.013026,-0.064356,-0.065923,0.061883,0.026700,0.064902,-0.074889,0.080298,0.026037,-0.067241,0.070075,0.099922,-0.057702,-0.040342,0.069343,0.021502,0.061742,0.054532,-0.016568,0.098932,-0.003447,-0.020230,0.033434,-0.031429,0.002442,0.028950,-0.044336,-0.002090,0.028610,0.020365,0.046056,0.068335,-0.006521,-0.005404,-0.007789,-0.014325,0.091656,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
49995,-0.039744,-0.138452,-0.048519,-0.013594,0.059558,0.042930,0.017416,-0.009105,-0.109466,-0.072587,0.000673,-0.046763,0.012116,0.128881,-0.003099,0.000403,0.106181,-0.092451,0.018649,-0.034738,-0.069772,0.026436,0.009395,-0.020877,-0.003191,-0.063882,-0.018499,0.046297,-0.043705,0.047268,0.019301,0.003820,-0.005984,0.097720,-0.107430,-0.028901,-0.000580,0.014745,-0.034962,0.047961,...,-0.008858,-0.020654,0.008338,-0.049711,-0.069257,0.056829,0.026461,0.042914,-0.067564,0.064707,0.031730,-0.065571,0.063366,0.095657,-0.054677,-0.037208,0.060044,0.010380,0.051916,0.051688,-0.011726,0.102404,0.003680,-0.020353,0.036301,-0.023288,-0.008478,0.012911,-0.040264,0.000421,0.039822,0.008751,0.047736,0.073547,-0.016875,-0.012376,-0.012272,-0.022055,0.094461,80
49996,-0.033311,-0.154913,-0.048542,0.019029,0.026837,0.039464,-0.007161,-0.018951,-0.127134,-0.073823,0.003997,-0.024385,0.035948,0.135793,-0.018350,0.010917,0.135912,-0.116252,0.029209,-0.010293,-0.066214,0.044366,-0.015401,-0.005268,-0.001001,-0.066109,-0.033078,0.032191,-0.059376,0.073745,0.048909,0.000453,-0.000066,0.088596,-0.108637,-0.003061,-0.024914,0.022549,-0.027748,0.048321,...,-0.032820,-0.034885,-0.010590,-0.069465,-0.055602,0.070492,0.013048,0.080791,-0.031355,0.126725,0.011706,-0.084333,0.044883,0.103006,-0.042772,-0.040560,0.096497,0.059455,0.052952,0.051636,-0.017026,0.088952,0.019078,-0.031264,0.043284,-0.018128,0.044918,0.033318,-0.060332,-0.020186,0.037701,0.030712,0.059983,0.061370,0.021487,-0.013252,-0.010820,-0.014747,0.109797,7
49997,-0.037295,-0.148652,-0.039271,0.002686,0.047185,0.051718,0.014766,-0.030629,-0.145150,-0.080993,0.004015,-0.045334,0.029302,0.142307,-0.020600,0.007400,0.108605,-0.109144,0.013952,-0.016163,-0.088131,0.050289,-0.006318,-0.021709,-0.007634,-0.047988,-0.049550,0.028302,-0.050328,0.050456,0.042161,0.004516,-0.003279,0.100406,-0.122801,-0.018862,-0.014005,0.018094,-0.016393,0.053565,...,-0.028222,-0.037659,-0.009726,-0.067386,-0.062971,0.069525,0.015013,0.068891,-0.026781,0.122626,0.009484,-0.070765,0.049979,0.101772,-0.053865,-0.036877,0.088673,0.050380,0.036127,0.056887,-0.015859,0.102470,0.026646,-0.044439,0.041413,-0.002694,0.039060,0.029949,-0.058964,-0.040449,0.050895,0.020167,0.052196,0.056272,0.015601,-0.003301,0.004039,-0.017802,0.109519,3
49998,-0.026447,-0.148029,-0.040077,0.023344,0.051278,0.053717,0.003016,-0.020872,-0.150415,-0.064110,0.011028,-0.034608,0.034324,0.132357,-0.025379,0.006460,0.114737,-0.122587,0.029596,-0.022679,-0.080963,0.039890,-0.006390,-0.016500,0.009919,-0.048911,-0.045991,0.051267,-0.032016,0.040193,0.031466,-0.016007,0.002310,0.091783,-0.120539,-0.010394,0.006364,0.025162,-0.032405,0.040319,...,-0.028109,-0.030094,-0.013604,-0.064137,-0.040589,0.062103,0.027986,0.062913,-0.040247,0.108458,0.017158,-0.059244,0.056143,0.109150,-0.036355,-0.033298,0.079851,0.037137,0.039292,0.053457,-0.012578,0.099699,0.010465,-0.041619,0.059859,-0.017220,0.032988,0.043940,-0.049764,-0.022031,0.050677,0.021247,0.045451,0.068254,0.017613,0.004955,-0.004342,-0.030023,0.123722,7


In [65]:
train_feat_df.to_csv('CIFAR_100_TRAIN_FEAT100VGG16',index=False)