In [29]:
# PlacesCNN for scene classification
#
# by Bolei Zhou
# last modified by Bolei Zhou, Dec.27, 2017 with latest pytorch and torchvision (upgrade your torchvision please if there is trn.Resize error)
!pip install torch
!pip install torchvision

import torch
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
import os
from PIL import Image
import numpy as np
import glob
import tqdm
import pandas
import boto3
import io
import matplotlib.pyplot as plt 
import pandas as pd

You should consider upgrading via the '/home/ec2-user/anaconda3/envs/amazonei_mxnet_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/amazonei_mxnet_p36/bin/python -m pip install --upgrade pip' command.[0m


In [51]:
# download data from S3
import sagemaker
from sagemaker import get_execution_role
 
role = get_execution_role()
print(role)
sess = sagemaker.Session()

# S3 settings
bucket='gsv-beauty-score-test-100'
if not os.path.isdir('../img_data'):
    os.makedirs('../img_data', exist_ok=True)
sess.download_data('../img_data/', bucket, extra_args=None)

arn:aws:iam::428024436188:role/service-role/AmazonSageMaker-ExecutionRole-20210205T145626


In [63]:
# the architecture to use
arch = 'resnet50'

# load the pre-trained weights
model_file = '{}_places365.pth.tar'.format(arch)
if not os.access(model_file, os.W_OK):
    weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
    os.system('wget ' + weight_url)

model = models.__dict__[arch](num_classes=365)
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()


# load the image transformer
centre_crop = trn.Compose([
        trn.Resize((256,256)),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# load the class label
file_name = 'categories_places365.txt'
if not os.access(file_name, os.W_OK):
    synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt'
    os.system('wget ' + synset_url)
classes = list()
with open(file_name) as class_file:
    for line in class_file:
        classes.append(line.strip().split(' ')[0][3:])
classes = tuple(classes)

# result dictionary
result_dict=dict()

# run through img list

for img_name in tqdm.tqdm(os.listdir('../img_data')):
    # open img
    img = Image.open(os.path.join('../img_data',img_name))
    input_img = V(centre_crop(img).unsqueeze(0))

    # forward pass
    logit = model.forward(input_img)
    h_x = F.softmax(logit, 1).data.squeeze()
    probs, idx = h_x.sort(0, True)

    # convert to np array
    probs_array=probs.numpy()
    classes_list=[classes[idx[i]] for i in range(len(idx))]
    classes_array=np.array(classes_list)

    # store the result in dictionaries
    temp_dict=dict(zip(classes_array,probs_array))
    result_dict[img_name.split('/')[-1].replace('.jpg','')]= temp_dict

# convert dict to df
result_df=pd.DataFrame.from_dict({(i): result_dict[i]
                          for i in result_dict.keys()},
                       orient='index',
                      columns=list(classes))

if not os.path.isdir('../tabular_data'):
    os.makedirs('../tabular_data', exist_ok=True)
result_df.to_csv('../tabular_data/classification.csv')
sess.upload_data(path='../tabular_data/classification.csv', bucket='tabular-data-bikeability', key_prefix='classification')




  0%|          | 0/119 [00:00<?, ?it/s][A[A[A


  1%|          | 1/119 [00:00<00:21,  5.39it/s][A[A[A


  2%|▏         | 2/119 [00:00<00:22,  5.27it/s][A[A[A


  3%|▎         | 3/119 [00:00<00:21,  5.42it/s][A[A[A


  3%|▎         | 4/119 [00:00<00:20,  5.55it/s][A[A[A


  4%|▍         | 5/119 [00:00<00:20,  5.59it/s][A[A[A


  5%|▌         | 6/119 [00:01<00:20,  5.65it/s][A[A[A


  6%|▌         | 7/119 [00:01<00:19,  5.71it/s][A[A[A


  7%|▋         | 8/119 [00:01<00:19,  5.79it/s][A[A[A


  8%|▊         | 9/119 [00:01<00:18,  5.85it/s][A[A[A


  8%|▊         | 10/119 [00:01<00:18,  5.88it/s][A[A[A


  9%|▉         | 11/119 [00:01<00:18,  5.72it/s][A[A[A


 10%|█         | 12/119 [00:02<00:18,  5.72it/s][A[A[A


 11%|█         | 13/119 [00:02<00:18,  5.73it/s][A[A[A


 12%|█▏        | 14/119 [00:02<00:18,  5.69it/s][A[A[A


 13%|█▎        | 15/119 [00:02<00:18,  5.71it/s][A[A[A


 13%|█▎        | 16/119 [00:02<00:18,  5.72it/s][A[A

's3://tabular-data-bikeability/classification/classification.csv'