<a href="https://colab.research.google.com/github/mehrdadrashidian/CG-Ecosystem/blob/master/Recommendation_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from flask import Flask, request, jsonify
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sentence_transformers import SentenceTransformer
import librosa
from torchvision import models, transforms
from PIL import Image
import torch
from pymilvus import connections, Collection, CollectionSchema, DataType, FieldSchema, utility
import decord
from transformers import TimesformerModel, TimesformerConfig

# Flask application that serves the /upload, /recommend and /assets routes.
app = Flask(__name__)

# Sentence-Transformers checkpoint used to embed "text" assets.
text_model = SentenceTransformer("all-MiniLM-L6-v2")

# ImageNet-pretrained ResNet-50 used for "image" assets. Module.eval()
# returns the module itself, so it is chained onto the constructor.
image_model = models.resnet50(pretrained=True).eval()

# Preprocessing shared by images and sampled video frames: fixed 224x224
# resize, conversion to a tensor, then ImageNet mean/std normalisation.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# TimeSformer checkpoint (Kinetics-400 fine-tuned) used for "video" assets.
# The configuration object is loaded separately from the model weights.
video_model_config = TimesformerConfig.from_pretrained(
    "facebook/timesformer-base-finetuned-k400"
)
video_model = TimesformerModel.from_pretrained(
    "facebook/timesformer-base-finetuned-k400"
)

# Milvus configuration
connections.connect("default", host="localhost", port="19530")
collection_name = "asset_vectors"

# All asset types share one collection, so the vector field must be at least
# as wide as the largest extractor output (ResNet-50 produces 1000 values).
# Narrower embeddings have to be padded to this width before insertion.
VECTOR_DIM = 1024

if not utility.has_collection(collection_name):
    # pymilvus only accepts a CollectionSchema object here; a plain dict
    # schema is rejected with SchemaNotReadyException.
    schema = CollectionSchema(
        fields=[
            FieldSchema(
                name="asset_id",
                dtype=DataType.VARCHAR,
                max_length=255,
                is_primary=True,
                auto_id=False,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=VECTOR_DIM),
            FieldSchema(name="asset_type", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="tags", dtype=DataType.VARCHAR, max_length=255),
        ],
        description="Multimodal asset embeddings",
    )
    collection = Collection(collection_name, schema)
    collection.create_index(
        field_name="vector",
        index_params={"metric_type": "IP", "index_type": "IVF_FLAT", "params": {"nlist": 128}},
    )
else:
    collection = Collection(collection_name)

# Milvus only serves query()/search() on a collection loaded into memory.
collection.load()

# Function to extract features based on Asset Type
def extract_features(asset_type, file_path):
    """Return a 1-D NumPy feature vector for the asset stored at *file_path*.

    The model used depends on *asset_type*:

    * ``"text"``  -- Sentence-Transformers embedding of the file's UTF-8 text.
    * ``"audio"`` -- 40 MFCC coefficients averaged over time.
    * ``"image"`` -- ResNet-50 output for the RGB image (one value per
      ImageNet class).
    * ``"video"`` -- TimeSformer ``last_hidden_state`` averaged over the token
      axis, computed on frames sampled evenly across the whole clip.

    The vector length therefore differs between asset types.

    Raises:
        ValueError: if *asset_type* is unsupported or the video has no frames.
    """
    if asset_type == "text":
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        return text_model.encode(text)
    elif asset_type == "audio":
        # sr=None keeps the file's native sampling rate instead of resampling.
        y, sr = librosa.load(file_path, sr=None)
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
        return np.mean(mfcc.T, axis=0)
    elif asset_type == "image":
        # Context manager closes the underlying file handle once decoded.
        with Image.open(file_path) as raw:
            img = raw.convert('RGB')
        batch = image_transform(img).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            features = image_model(batch).squeeze(0).numpy()
        return features
    elif asset_type == "video":
        vr = decord.VideoReader(file_path)
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError("Video contains no frames")
        # Pick exactly `num_frames` indices spread evenly over the clip. A fixed
        # stride of len(vr)//8 is zero for clips shorter than 8 frames (range()
        # then raises) and yields more than 8 frames whenever the length is not
        # a multiple of 8. Very short clips simply repeat frames here.
        num_frames = video_model_config.num_frames
        indices = np.linspace(0, total_frames - 1, num=num_frames).round().astype(int)
        frames = [vr[int(i)].asnumpy() for i in indices]
        # NOTE(review): frames reuse the ImageNet mean/std from image_transform;
        # the TimeSformer checkpoint's own image processor may use different
        # normalisation constants -- confirm.
        processed_frames = torch.stack([
            image_transform(Image.fromarray(frame)) for frame in frames
        ]).unsqueeze(0)  # (1, num_frames, C, H, W)
        with torch.no_grad():
            outputs = video_model(pixel_values=processed_frames)
        return outputs.last_hidden_state.mean(dim=1).squeeze(0).numpy()
    else:
        raise ValueError("Unsupported asset type")

def normalize_vector(vector, dim):
    """Return *vector* as a float32 array of length *dim* with unit L2 norm.

    Every asset type shares one fixed-width vector field, so shorter vectors
    are zero-padded and longer ones truncated to *dim* first. With unit-norm
    vectors the collection's inner-product (IP) metric equals cosine
    similarity. An all-zero vector is returned as zeros (no division).
    """
    flat = np.asarray(vector, dtype=np.float32).ravel()
    fitted = np.zeros(dim, dtype=np.float32)
    keep = min(dim, flat.size)
    fitted[:keep] = flat[:keep]
    norm = float(np.linalg.norm(fitted))
    if norm > 0.0:
        fitted /= norm
    return fitted


def _vector_dim():
    """Return the declared dimension of the collection's ``vector`` field."""
    for field in collection.schema.fields:
        if field.name == "vector":
            return int(field.params["dim"])
    raise RuntimeError("Collection has no 'vector' field")


# Normalize and store vectors in Milvus
def store_asset(asset_id, asset_type, file_path, tags):
    """Extract, normalise and insert one asset's embedding into Milvus.

    Args:
        asset_id: primary key stored in the ``asset_id`` field.
        asset_type: one of the types accepted by ``extract_features``.
        file_path: path of the file to embed.
        tags: iterable of strings, stored comma-joined in ``tags``.
    """
    vector = normalize_vector(extract_features(asset_type, file_path), _vector_dim())
    # Column-based insert: one list per schema field, in schema order.
    collection.insert([
        [asset_id],
        [vector.tolist()],
        [asset_type],
        [",".join(tags)],
    ])

# Endpoint to upload an asset
@app.route('/upload', methods=['POST'])
def upload_asset():
    """Embed a file that already exists on the server and store it in Milvus.

    Expected JSON body::

        {"asset_id": str, "asset_type": "text"|"audio"|"image"|"video",
         "file_path": str, "tags": [str, ...]}      # tags is optional

    Returns 200 on success, 400 for a malformed request or a path that is not
    an existing file, and 500 for any other failure (``{"error": ...}``).

    NOTE(review): ``file_path`` is opened exactly as supplied. If this endpoint
    is reachable by untrusted clients it can be used to read arbitrary server
    files -- restrict it to a dedicated upload directory.
    """
    # silent=True: a missing or invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    missing = [key for key in ('asset_id', 'asset_type', 'file_path') if key not in data]
    if missing:
        return jsonify({"error": f"Missing field(s): {', '.join(missing)}"}), 400

    asset_id = data['asset_id']
    asset_type = data['asset_type']
    file_path = data['file_path']
    tags = data.get('tags', [])

    if not isinstance(asset_id, str) or not asset_id:
        return jsonify({"error": "asset_id must be a non-empty string"}), 400
    if asset_type not in ['text', 'audio', 'image', 'video']:
        return jsonify({"error": "Invalid asset type"}), 400
    if not isinstance(file_path, str) or not os.path.isfile(file_path):
        return jsonify({"error": "file_path does not point to an existing file"}), 400
    # A bare string would otherwise be comma-joined character by character.
    if not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags):
        return jsonify({"error": "tags must be a list of strings"}), 400

    try:
        store_asset(asset_id, asset_type, file_path, tags)
    except Exception as e:  # top-level request boundary: log, report, don't crash
        app.logger.exception("Upload of asset %r failed", asset_id)
        return jsonify({"error": str(e)}), 500

    return jsonify({"message": "Asset uploaded successfully"}), 200

def _milvus_string_literal(value):
    """Quote *value* as a double-quoted Milvus filter-expression string.

    Backslashes, double quotes and line breaks are escaped so a user-supplied
    id cannot terminate the literal and inject extra expression syntax.
    """
    escaped = (
        value.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    return f'"{escaped}"'


# Endpoint to get recommendations based on an asset
@app.route('/recommend', methods=['POST'])
def recommend():
    """Return the stored assets most similar to an already-stored asset.

    Expected JSON body: ``{"asset_id": str, "top_n": int}`` (``top_n`` is
    optional, default 5, must be >= 1).

    Response: ``{"recommendations": [{"asset_id", "asset_type",
    "similarity"}, ...]}`` in the order Milvus returns them for the IP metric.
    The query asset itself is excluded. Returns 400 for a malformed request,
    404 if the asset id is unknown, 500 for any other failure.

    NOTE(review): vectors of different asset types come from different
    models, so cross-type similarity scores are not meaningful -- consider
    filtering the search by ``asset_type``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('asset_id'), str):
        return jsonify({"error": "asset_id (string) is required"}), 400
    asset_id = data['asset_id']
    top_n = data.get('top_n', 5)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(top_n, bool) or not isinstance(top_n, int) or top_n < 1:
        return jsonify({"error": "top_n must be a positive integer"}), 400

    try:
        id_literal = _milvus_string_literal(asset_id)
        query = collection.query(expr=f"asset_id == {id_literal}", output_fields=["vector"])
        if not query:
            return jsonify({"error": "Asset ID not found"}), 404

        query_vector = np.array(query[0]['vector']).reshape(1, -1)
        search_params = {"metric_type": "IP", "params": {"nprobe": 10}}
        results = collection.search(
            data=query_vector.tolist(),
            anns_field="vector",
            param=search_params,
            limit=top_n,
            # Otherwise the query asset is always its own best match.
            expr=f"asset_id != {id_literal}",
            output_fields=["asset_id", "asset_type"],
        )

        # search() returns one Hits list per query vector; there is exactly one.
        hits = results[0]
        recommendations = [
            {
                "asset_id": hit.entity.get("asset_id"),
                "asset_type": hit.entity.get("asset_type"),
                "similarity": hit.distance,
            }
            for hit in hits
        ]

        return jsonify({"recommendations": recommendations}), 200
    except Exception as e:  # top-level request boundary: log, report, don't crash
        app.logger.exception("Recommendation for asset %r failed", asset_id)
        return jsonify({"error": str(e)}), 500

# Endpoint to list all assets
@app.route('/assets', methods=['GET'])
def list_assets():
    """Return ``asset_id``, ``asset_type`` and ``tags`` of the stored assets.

    Milvus rejects an empty filter expression unless a ``limit`` is supplied,
    so a predicate matched by every row with a non-empty ``asset_id`` is used.
    Returns a JSON list on success, ``{"error": ...}`` with 500 on failure.
    """
    try:
        assets = collection.query(
            expr='asset_id != ""',
            output_fields=["asset_id", "asset_type", "tags"],
        )
        return jsonify(list(assets)), 200
    except Exception as e:  # top-level request boundary: log, report, don't crash
        app.logger.exception("Listing assets failed")
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    # Flask's debug mode enables the interactive Werkzeug debugger, which
    # allows arbitrary code execution from the browser, and the auto-reloader.
    # Only turn it on when explicitly requested via the environment.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1")
