In [None]:
import socket
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import struct

# (English label, Spanish name) pairs, in model-output order: the index predicted
# by the model is looked up in `class_names`, so the position of each pair is the
# label index. These are the Fashion-MNIST class names in that dataset's label
# order (presumably the model was trained on it -- confirm against training code).
# NOTE(review): the Spanish names are not used in this file chunk; presumably they
# are consumed elsewhere (e.g. for speech on the robot) -- confirm.
_LABEL_PAIRS = (
    ("T-shirt/top", "polo"),
    ("Trouser", "pantalón"),
    ("Pullover", "suéter"),
    ("Dress", "vestido"),
    ("Coat", "abrigo"),
    ("Sandal", "sandalia"),
    ("Shirt", "camisa"),
    ("Sneaker", "zapatilla"),
    ("Bag", "bolso"),
    ("Ankle boot", "botín"),
)

# English label -> Spanish name (insertion order preserved).
clases_es = dict(_LABEL_PAIRS)

# English labels indexed by model output index.
class_names = [english for english, _spanish in _LABEL_PAIRS]

# CNN model definition.
class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 class logits.

    All layers live in one ``nn.Sequential`` stored as ``self.net``, in a fixed
    order. The parameter names (``net.0.*``, ``net.3.*``, ``net.7.*``,
    ``net.9.*``) therefore depend on that order and must match the checkpoint
    loaded with ``load_state_dict``. Do not reorder, insert or remove layers
    without retraining.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 3, padding=1),   # 1 -> 32 channels; 3x3 + padding 1 keeps H, W
            nn.ReLU(),
            nn.MaxPool2d(2),                  # halves H and W
            nn.Conv2d(32, 64, 3, padding=1),  # 32 -> 64 channels
            nn.ReLU(),
            nn.MaxPool2d(2),                  # halves H and W again
        ]
        head_layers = [
            nn.Flatten(),
            # 64 * 7 * 7 input features is consistent with a 28x28 input
            # (28 -> 14 -> 7 after the two 2x2 pools).
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return raw, unnormalized class scores of shape (batch, 10).

        ``x`` is expected to be (batch, 1, H, W) with H and W reducing to 7x7
        after the two pools, e.g. 28x28. No softmax is applied.
        """
        logits = self.net(x)
        return logits

# Load the trained "fashionista" weights into a fresh CNN for CPU inference.
model = CNN()
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs. It prevents a tampered checkpoint file from executing
# arbitrary code on load, and it silences torch.load's FutureWarning about the
# weights_only default.
state_dict = torch.load(
    "fashionista.pth",
    map_location=torch.device('cpu'),  # load on CPU even if saved from a GPU
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before serving predictions

# Preprocessing applied to every incoming frame before inference.
# NOTE(review): presumably mirrors the preprocessing used at training time --
# confirm against the training code; any mismatch silently degrades accuracy.
transform = transforms.Compose([
    transforms.Grayscale(),                # one channel, matching the CNN's Conv2d(1, ...)
    transforms.Resize((28, 28)),           # exact (h, w): squashes the frame, aspect ratio not kept
    transforms.ToTensor(),                 # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> range [-1, 1]
])

# TCP server: receive one raw frame per connection from the NAO robot, classify
# it and reply with the predicted label.
#
# Wire protocol (as implemented below):
#   client -> server: 4-byte big-endian unsigned length, then that many bytes of
#                     raw RGB pixels, decoded as a 640x480 image.
#   server -> client: the English class name (UTF-8), or "unknown" on any error;
#                     the connection is then closed.

RECV_TIMEOUT_S = 30.0  # per-operation socket timeout so a stalled client cannot block the server forever
_RECV_CHUNK = 65536    # upper bound on a single recv() buffer


def _recv_exact(sock, n):
    """Read ``n`` bytes from ``sock``, looping over partial ``recv`` results.

    Returns fewer than ``n`` bytes only if the peer closed the connection first.
    Chunks are collected in a list and joined once, which avoids the quadratic
    cost of repeated ``bytes +=``. Each recv buffer is capped, so a bogus,
    client-declared length never triggers a huge allocation.
    """
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(min(remaining, _RECV_CHUNK))
        if not chunk:  # orderly EOF from the peer
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _predict_label(raw_rgb):
    """Decode a raw 640x480 RGB buffer and return the predicted English class name.

    Raises ``ValueError`` (from PIL) if ``raw_rgb`` is too short for the frame.
    """
    image = Image.frombytes('RGB', (640, 480), raw_rgb)
    batch = transform(image).unsqueeze(0)  # add batch dimension: (1, 1, 28, 28)
    with torch.no_grad():
        output = model(batch)
        _, pred = torch.max(output, 1)
    return class_names[pred.item()]  # e.g. "Sneaker"


# create_server sets SO_REUSEADDR on POSIX, so re-running this cell right after a
# restart does not fail with "Address already in use". The `with` closes the
# listening socket if the loop is interrupted (e.g. a notebook interrupt).
with socket.create_server(("0.0.0.0", 5050), backlog=1) as server:
    print("Esperando conexión de NAO...")

    while True:
        conn, addr = server.accept()
        print("Conectado desde: {}".format(addr))

        with conn:  # always closes the client socket
            conn.settimeout(RECV_TIMEOUT_S)
            try:
                # The 4-byte length header may itself arrive split across recv() calls.
                size_data = _recv_exact(conn, 4)
                if len(size_data) < 4:
                    raise ValueError("No se recibió tamaño de imagen")
                size = struct.unpack('>I', size_data)[0]

                data = _recv_exact(conn, size)
                if len(data) < size:
                    raise ValueError("Datos de imagen incompletos")

                clase_en = _predict_label(data)
                print("Predicción:", clase_en)
                conn.sendall(clase_en.encode("utf-8"))
                print("Predicción enviada:", clase_en)

            except Exception as e:  # per-connection boundary: report and keep serving
                print("Error:", e)
                try:
                    conn.sendall("unknown".encode("utf-8"))
                except OSError as send_err:
                    # The peer is already gone (reset/timeout). Without this guard
                    # the exception would escape and stop the whole server loop.
                    print("Error:", send_err)


  model.load_state_dict(torch.load("fashionista.pth", map_location=torch.device('cpu')))


Esperando conexión de NAO...
Conectado desde: ('192.168.108.235', 57497)
Predicción: Bag
Predicción enviada: Bag
Conectado desde: ('192.168.108.235', 57503)
Predicción: Bag
Predicción enviada: Bag
