In [6]:
import glob

import cv2
import torch
import torch.nn as nn
import torchvision
from numpy import argmax
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from torchvision import transforms

In [7]:
# Connect to the local MongoDB instance. This is best-effort: predictions still run without it,
# and `db` stays None when no server is reachable.
# MongoClient connects lazily, so ping explicitly to detect a missing server now instead of on
# the first insert; the short serverSelectionTimeoutMS keeps that check from blocking for the
# driver's default timeout when nothing is listening.
MONGO_HOST = "localhost"
MONGO_PORT = 27017
db = None
try:
    conn = MongoClient(host=MONGO_HOST, port=MONGO_PORT, serverSelectionTimeoutMS=2000)
    conn.admin.command("ping")
    # Use the `local` database, as before.
    db = conn.local
except PyMongoError as exc:
    print(f"MongoDB not available, predictions will not be stored: {exc}")

In [8]:
# Definir modelo
class scratch_nn(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=100, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(100, 200, 3, stride=1, padding=0)
        self.conv3 = nn.Conv2d(200, 400, 3, stride=1, padding=0)
        self.mpool = nn.MaxPool2d(kernel_size=3)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(19600,1024)
        self.linear2 = nn.Linear(1024,512)
        self.linear3 = nn.Linear(512,2)
        self.classifier = nn.Softmax(dim=1)
        
    def forward(self,x):
        x = self.mpool( self.relu(self.conv1(x)) )
        x = self.mpool( self.relu(self.conv2(x)) )
        x = self.mpool( self.relu(self.conv3(x)) )
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.classifier(x)
        return x

In [9]:

# Cargar modelo entrenado
device = torch.device('cpu')
model = scratch_nn()
model.load_state_dict(torch.load("dogs_cats_model.pth"))
model.eval()
model = model.to(device)

# Definir preprocesados de la imagen
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])


In [None]:
# Realizar la prediccion de todas las imagenes en la carpeta
labels = ["Cat", "Dog"]
for image_path in glob.glob("predict_cat_dog/*.jpg"):
	img_orig = cv2.imread(image_path)
	img = data_transform(img_orig).unsqueeze(0).to(device)
	outputs = model(img)
	outputs = outputs.detach().cpu().numpy()
	output = argmax(outputs, axis=1)[0]
	print("Predicted label: "+labels[output])
	cv2.imshow("Predicted label: "+labels[output], img_orig)
	cv2.waitKey(0)
	cv2.destroyAllWindows()
	# Almacenar en base de datos
	try:
		db.data.insert_one({"path_img": image_path, "predicted_label": labels[output]})
	except:
		pass

Predicted label: Dog


In [2]:
pip install pymongo

Collecting pymongo
  Downloading pymongo-4.3.3-cp38-cp38-macosx_10_9_x86_64.whl (381 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m381.9/381.9 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dnspython<3.0.0,>=1.16.0
  Downloading dnspython-2.3.0-py3-none-any.whl (283 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m283.7/283.7 kB[0m [31m35.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: dnspython, pymongo
Successfully installed dnspython-2.3.0 pymongo-4.3.3
Note: you may need to restart the kernel to use updated packages.
