# In [None]:  (Jupyter notebook cell marker left over from the export)
from PIL import Image
import requests
from io import BytesIO
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
import copy

# Helper that opens an image stored on the local filesystem
def load_image(image_path):
    """Open the file at ``image_path`` with PIL and return the image object."""
    opened = Image.open(image_path)
    return opened

# Example usage: load a local image and display it.
# NOTE: 'YOUR IMAGE PATH.jpg' is a placeholder; point it at a real file before running.
image_path = 'YOUR IMAGE PATH.jpg'
image = load_image(image_path)
image.show()  # Displays the image

# Load the pretrained BLIP captioning checkpoint and its matching processor.
# Both are module-level globals that generate_description_old() reads.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

def generate_description_old(image):
    """Caption ``image`` using the module-level BLIP ``processor`` and ``model``.

    Args:
        image: The image to caption, passed to the processor as ``images=``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Turn the image into model-ready PyTorch tensors.
    encoded = processor(images=image, return_tensors="pt")

    # Let the model produce caption token ids, then keep the first sequence.
    token_ids = model.generate(**encoded)
    first_sequence = token_ids[0]

    return processor.decode(first_sequence, skip_special_tokens=True)

# Example usage: caption the image loaded above with the base (not yet fine-tuned) checkpoint
description = generate_description_old(image)
print("Generated Description:", description)

import torch
import matplotlib
import numpy as np
from datasets import load_dataset, DatasetDict
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, TrainingArguments, Trainer
#from matplotlib import c

# Load the Fashion-MNIST dataset.
# NOTE: download_mode='force_redownload' is requested here, so the local cache is not reused.
fashion_mnist=load_dataset('fashion_mnist', download_mode='force_redownload', cache_dir=None)

# Optional: load the full Fashion-MNIST train and test splits separately if you want to subset them
#train_dataset = load_dataset('fashion_mnist', split='train', download_mode='force_redownload', cache_dir=None)
#test_dataset = load_dataset('fashion_mnist', split='test', download_mode='force_redownload', cache_dir=None)

# Select the first 10,000 rows from the training set
#train_subset = train_dataset.select(range(10000))

# Select the first 500 rows from the test set
#test_subset = test_dataset.select(range(500))


# Rebuild the DatasetDict from the subsets (replaces the full dataset loaded above)
# fashion_mnist = DatasetDict({
#     'train': train_subset,               # Assign the train subset
#     'test': test_subset                # Assign the test subset
#     #'validation': train_subset.select(range(5000))  # Optionally create a validation set
# })



# Load the BLIP processor and model (a fresh copy of the base checkpoint, used for fine-tuning below)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


# Clear PyTorch's CUDA memory cache before training the model
import torch
torch.cuda.empty_cache() 

# Imports used to preprocess the Fashion-MNIST dataset
import torch
import PIL
from torchvision import transforms
from transformers import AutoTokenizer


# Initialize the tokenizer from the same checkpoint as the BLIP model
tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip-image-captioning-base")  # Replace with the model you are using


def ResizeMax(img, max_sz=256):
    """Shrink a PIL image in place so that neither side exceeds ``max_sz``.

    Args:
        img: A ``PIL.Image.Image``. It is modified in place by ``thumbnail``.
        max_sz: Maximum width and height, in pixels, of the bounding box
            passed to ``thumbnail``.

    Returns:
        The same image object that was passed in.

    Raises:
        ValueError: If ``img`` is not a PIL image.
    """
    if not isinstance(img, Image.Image):
        raise ValueError("Expected a PIL image")

    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
    # filter under its current name and also exists in older releases.
    img.thumbnail((max_sz, max_sz), Image.LANCZOS)
    return img
        
# Function to preprocess images and labels from Fashion-MNIST
def preprocess_fashion_mnist(examples):
    """Convert a batch of Fashion-MNIST rows into BLIP model inputs.

    Intended to be used with ``Dataset.map(..., batched=True)``. Relies on the
    module-level ``processor``.

    Args:
        examples: Batch dictionary; ``examples['image']`` holds the images and
            ``examples['label']`` the integer class labels.

    Returns:
        A dict with:
            'pixel_values': image tensor produced by the processor,
            'labels': tensor built from ``examples['label']``,
            'input_ids': token ids of the fixed caption, one row per image.
    """
    to_rgb_tensor = transforms.Compose([
        # Convert to an 8-bit array, then to an RGB PIL image
        transforms.Lambda(lambda x: Image.fromarray(np.uint8(x)).convert('RGB')),
        transforms.Resize((256, 256)),
        transforms.PILToTensor(),  # PIL image -> tensor
    ])

    images = [to_rgb_tensor(image) for image in examples['image']]

    # Every image gets the same caption text.
    text_data = ["A picture of a clothing item."] * len(images)

    # With return_tensors="pt" the processor already returns tensors, so no
    # list-to-tensor fallback is needed for 'pixel_values'.
    inputs = processor(images=images, text=text_data, return_tensors="pt", padding=True, truncation=True)

    return {
        'pixel_values': inputs['pixel_values'],
        'labels': torch.tensor(examples['label']),
        'input_ids': inputs['input_ids'],
    }

# Apply preprocessing to the dataset
fashion_mnist = fashion_mnist.map(preprocess_fashion_mnist, batched=True, batch_size=8)


# Work on a deep copy of the DatasetDict so the output format is changed only on the copy;
# set_format(type='pt') makes the listed columns come back as tensors instead of lists.
# output_all_columns=True keeps the remaining columns available as well.
formatted_fashion_mnist = copy.deepcopy(fashion_mnist)
formatted_fashion_mnist.set_format(type='pt', columns=['pixel_values', 'label','image'], output_all_columns=True)
print(fashion_mnist)
print(formatted_fashion_mnist)

#Do some checks on the output 
from torchvision import transforms

print(type(formatted_fashion_mnist))

# Access one sample (row 435) from the 'train' split
first_sample = formatted_fashion_mnist['train'][435]

print(f"First sample type: {type(first_sample)}")
print(f"First sample keys: {list(first_sample.keys())}")

# Pull out the individual fields for inspection:
# 'pixel_values' and 'labels' are produced by preprocess_fashion_mnist,
# 'label' and 'image' are the columns that function reads from the dataset.
pixel_values = first_sample['pixel_values']
labels = first_sample['labels']
label = first_sample['label']
image=first_sample['image']

# Uncomment to print the types of the extracted fields
#print(f"Pixels type: {type(pixel_values)}")
#print(f"Labels type: {type(labels)}")
#print(f"Label type: {type(label)}")
#print(f"Image is: {image}")



# Convert the pixels into image

def display_image_from_pixel_values(pixel_values):
    """Display a channel-first image tensor as a PIL image.

    Args:
        pixel_values: Tensor of layout (C, H, W), or a 4-D batch of such
            tensors, in which case only the first image is shown. Accepted
            value ranges:
              * values containing negatives (e.g. mean/std-normalised model
                inputs) are min-max rescaled to 0-255 for viewing;
              * non-negative values with a maximum of at most 1.0 are treated
                as 0-1 and multiplied by 255;
              * anything else is treated as already being in 0-255.

    Returns:
        None. The image is opened with ``Image.show()``.
    """
    # Keep only the first image of a batch.
    if pixel_values.ndim == 4:
        pixel_values = pixel_values[0]

    # Work on a detached CPU float copy so .numpy() is always possible.
    pixel_values = pixel_values.detach().cpu().float()
    low = pixel_values.min().item()
    high = pixel_values.max().item()

    if low < 0.0:
        # Negative values cannot be cast to uint8 meaningfully; stretch the
        # observed range onto 0-255 instead.
        span = high - low
        if span > 0.0:
            pixel_values = (pixel_values - low) / span * 255.0
        else:
            pixel_values = pixel_values - low  # constant image -> all zeros
    elif high <= 1.0:
        # Assume values are normalized between 0 and 1.
        pixel_values = pixel_values * 255.0

    # Clamp before the uint8 cast so out-of-range values do not wrap around,
    # then move channels last as PIL expects.
    image_np = pixel_values.clamp(0, 255).permute(1, 2, 0).byte().numpy()

    image_pil = Image.fromarray(image_np)
    image_pil.show()

# Example usage: show the preprocessed pixel values of training sample 5000
# Assumes `formatted_fashion_mnist` has been set up correctly
first_sample = formatted_fashion_mnist['train'][5000]
pixel_values = first_sample['pixel_values']

# Display the image
display_image_from_pixel_values(pixel_values)

from transformers import TrainingArguments, Trainer

class CustomTrainer(Trainer):
    """Trainer that computes the BLIP captioning loss from a batch.

    The caption token ids (``input_ids``) are used as the labels, so the
    model is trained to reproduce the caption text for each image.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on a batch and return its loss.

        Args:
            model: The model being trained.
            inputs: Batch mapping; must provide 'pixel_values' and 'input_ids'.
            return_outputs: When True, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted for compatibility with Transformers
                releases that pass this keyword to ``compute_loss``; unused.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If 'pixel_values' or 'input_ids' is missing.
        """
        print('Computing loss...')

        # Extract pixel_values and input_ids; the caption tokens double as labels
        pixel_values = inputs.get('pixel_values')
        input_ids = inputs.get('input_ids')
        labels = inputs.get('input_ids')

        # Debug statements to check types and shapes
        print(f"Pixel values type: {type(pixel_values)}, shape: {pixel_values.shape if pixel_values is not None else 'None'}")
        print(f"Input IDs type: {type(input_ids)}, shape: {input_ids.shape if input_ids is not None else 'None'}")
        print(f"Labels type: {type(labels)}, shape: {labels.shape if labels is not None else 'None'}")

        # Fail early with a clear message if a required input is missing
        if pixel_values is None or input_ids is None or labels is None:
            raise ValueError("Pixel values, input ids, or labels are None")

        # Pass only the inputs the model needs, ignoring any extra batch keys
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            labels=labels
        )

        # Extract loss
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


# Define training arguments
training_args = TrainingArguments(
    output_dir="./blip-finetuned-fashion-mnist",
    bf16=True,  # Enable bfloat16 mixed-precision training
    per_device_train_batch_size=12,
    #gradient_accumulation_steps=2,  # Accumulate gradients over steps to simulate a larger batch size
    eval_strategy="steps",
    #dataloader_num_workers=4,  # makes it slower for now
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
    #remove_unused_columns=True,
    remove_unused_columns=False,  # keep every dataset column in the batches
    logging_dir="./logs",
)

# Initialize Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_fashion_mnist['train'],
    eval_dataset=formatted_fashion_mnist['test'],
)

# Start training
trainer.train()

# Write the final model and the processor to the root of output_dir.
# The inference code below loads both from "./blip-finetuned-fashion-mnist";
# without this step only intermediate checkpoints would exist on disk.
trainer.save_model(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch

# Load the fine-tuned BLIP model and processor once from the training output
# directory (the root of the folder written by the training step).
model_path = "./blip-finetuned-fashion-mnist"  # Path to the fine-tuned model
processor = BlipProcessor.from_pretrained(model_path)
model = BlipForConditionalGeneration.from_pretrained(model_path)

# Set the model to evaluation mode
model.eval()

# Function to load and process an image
def load_and_preprocess_image(image_path):
    """Load an image file and turn it into BLIP pixel values.

    Uses the module-level ``processor``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The ``'pixel_values'`` tensor produced by the processor.
    """
    # Use a context manager so the file handle is closed right away;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Preprocess the image with the processor
    inputs = processor(images=image, return_tensors="pt")

    return inputs['pixel_values']

# Function to generate description for an image
def generate_description(image_path):
    """Return a generated caption for the image stored at ``image_path``.

    The image is prepared by ``load_and_preprocess_image`` and captioned with
    the module-level ``model``; the first generated sequence is decoded by
    the module-level ``processor`` with special tokens removed.
    """
    pixels = load_and_preprocess_image(image_path)

    # No gradients are needed for inference.
    with torch.no_grad():
        token_ids = model.generate(pixels)

    caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return caption

# Example usage
image_path = "PATH_TO_IMAGE.jpg"  # Placeholder: replace with the path to the image you want to describe

# Caption the image with the fine-tuned model and print the result
description = generate_description(image_path)
print("Generated Description:", description)

#EXTRA CODE TO EXTRACT ATTRIBUTES TOO

from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch

import spacy
import webcolors
from collections import Counter

# Load the spaCy English pipeline; extract_attributes() uses its part-of-speech tags
nlp = spacy.load("en_core_web_sm")

# Load the fine-tuned BLIP model and processor
model_path = "./blip-finetuned-fashion-mnist"  # Path to the fine-tuned model
processor = BlipProcessor.from_pretrained(model_path)
model = BlipForConditionalGeneration.from_pretrained(model_path)


# Function to generate a description from an image using BLIP
def generate_description(image_path):
    """Generate a caption for the image at ``image_path`` using BLIP.

    Uses the module-level ``processor`` and ``model``.

    Args:
        image_path: Path of the image file to caption.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(image_path) as source:
        raw_image = source.convert('RGB')

    # Prepare inputs for the BLIP model
    inputs = processor(raw_image, return_tensors="pt")

    # Inference only: make sure no gradients are tracked.
    with torch.no_grad():
        outputs = model.generate(**inputs)

    return processor.decode(outputs[0], skip_special_tokens=True)

# Function to dynamically extract colors using webcolors
def extract_colors(description):
    """Find CSS3 colour names mentioned in ``description``.

    Args:
        description: Free text to scan. Matching is case-insensitive.

    Returns:
        A ", "-joined string of the distinct colour names found, in order of
        first appearance, or ``None`` when no colour name is present.
    """
    # Colour vocabulary: the CSS3 named colours known to webcolors
    color_names = set(webcolors.names("css3"))

    # Replace every non-letter with a space before splitting, so that
    # punctuation attached to a word ("red," / "blue.") does not hide a match.
    letters_only = "".join(ch if ch.isalpha() else " " for ch in description.lower())
    detected_colors = [word for word in letters_only.split() if word in color_names]

    if not detected_colors:
        return None

    # dict.fromkeys removes duplicates while keeping a stable order
    return ", ".join(dict.fromkeys(detected_colors))

# Function to dynamically extract possible styles, fits, and other attributes using NLP
def extract_attributes(description):
    """Derive simple clothing attributes from a caption.

    Uses the module-level spaCy pipeline ``nlp`` and ``extract_colors``.

    Args:
        description: Caption text to analyse.

    Returns:
        A dict with the keys "Style", "Color", "Fit", "Season" and "Brand".
        "Style" lists the adjectives and nouns found, "Fit" the adjectives
        that are known fit words, "Season" is "Spring, Summer" or
        "Fall, Winter" when a season word occurs, and "Brand" is always the
        fixed text "Specific brand not identified". Values that could not be
        determined are ``None``. Multi-valued entries are ", "-joined in order
        of first appearance.
    """
    # Process the lower-cased description with spaCy NLP
    doc = nlp(description.lower())

    attributes = {
        "Style": None,
        "Color": None,
        "Fit": None,
        "Season": None,
        "Brand": None,
    }

    # Adjectives and nouns often describe style and fit
    adjectives = [token.text for token in doc if token.pos_ == 'ADJ']
    nouns = [token.text for token in doc if token.pos_ == 'NOUN']

    # Set style (all adjectives/nouns are considered style-related);
    # dict.fromkeys de-duplicates while keeping a stable order
    possible_styles = adjectives + nouns
    if possible_styles:
        attributes["Style"] = ", ".join(dict.fromkeys(possible_styles))

    # Extract color dynamically using the webcolors library
    attributes["Color"] = extract_colors(description)

    # Fit extraction (search the adjectives for common fit-related words)
    fit_keywords = {"fitted", "loose", "tailored", "flared", "relaxed", "oversized"}
    detected_fits = [word for word in adjectives if word in fit_keywords]
    if detected_fits:
        attributes["Fit"] = ", ".join(dict.fromkeys(detected_fits))

    # Season extraction: compare whole tokens rather than substrings, so that
    # e.g. "falls" or "springy" is not mistaken for a season.
    words = {token.text for token in doc}
    if words & {"spring", "summer"}:
        attributes["Season"] = "Spring, Summer"
    elif words & {"fall", "winter"}:
        attributes["Season"] = "Fall, Winter"

    # Brand extraction - no brand detection is attempted
    attributes["Brand"] = "Specific brand not identified"

    return attributes

# Example usage: caption an image, then derive attributes from the caption
image_path = "PATH_TO_IMAGE.jpg"  # Placeholder: replace with a real image path
description = generate_description(image_path)
attributes = extract_attributes(description)

print("Description:", description)
print("Attributes:", attributes)