In [99]:
import datetime
import pandas as pd
from typing import List

In [100]:
# Excel sheet loaded by QueryImage; it must provide the columns image_path, prev_pred, c17_eb_pred and c17_nb_pred.
# NOTE(review): "/content/..." looks like a Colab-specific absolute path — consider a configurable data dir.
xls_path = "/content/c17_separate_groups.xlsx"

In [101]:
class QueryImage:
    """Look up labelled chart images and find other images that share the same labels.

    The Excel sheet at ``xls_path`` is loaded once in ``__init__``. It must contain an
    ``image_path`` column plus the label columns listed in ``LABEL_COLUMNS``.
    Image paths are built as::

        {base_path}/{img_type}/{index_name}/{yyyy}/{mm}_{MonthName}/{dd}_{mm}_{pred_type}_15-25.png
    """

    # Accepted values for ``img_type``; each is also the folder directly under base_path.
    IMG_TYPES = ("correction", "impulse")
    # Accepted values for ``pred_type``; each is embedded verbatim in the file name.
    PRED_TYPES = ("1d_eb", "c17_eb", "c17_nb")
    # Label columns read from the sheet, in the order get_label_tuple returns them.
    LABEL_COLUMNS = ("prev_pred", "c17_eb_pred", "c17_nb_pred")
    # Fixed tail of every image file name (meaning not documented in this notebook — TODO confirm).
    FILE_SUFFIX = "15-25"

    def __init__(self, base_path: str, index_name: str, xls_path: str):
        """
        Args:
            base_path: Prefix used when building image paths. It must match the prefix
                stored in the sheet's ``image_path`` column, or lookups will find nothing.
            index_name: Index folder name, e.g. ``"NIFTY50"``.
            xls_path: Excel file to load with ``pd.read_excel``.
        """
        self.base_path = base_path
        self.index_name = index_name
        self.xls_path = xls_path
        self.xls_data = pd.read_excel(self.xls_path)

    def generate_image_path(self, date: str, pred_type: str, img_type: str) -> str:
        """Build the image path for one date, prediction type and image type.

        Args:
            date: ``"dd-mm-yyyy"``, e.g. ``"24-03-2020"``. Day and month are zero-padded to
                two digits, so ``"5-3-2020"`` yields the same path as ``"05-03-2020"``.
            pred_type: One of ``PRED_TYPES``.
            img_type: One of ``IMG_TYPES``.

        Returns:
            The image path string (the file system is not checked).

        Raises:
            ValueError: If ``date`` does not have three dash-separated parts, its month is
                not 1-12, or ``img_type`` / ``pred_type`` is not recognised.
        """
        parts = date.split("-")
        if len(parts) != 3:
            raise ValueError(f"Invalid date {date!r}. Expected 'dd-mm-yyyy'.")
        day, month, year = parts
        # Pad both parts so the month folder ("03_March") and the file name ("24_03_...") agree.
        day, month = day.zfill(2), month.zfill(2)
        # Full month name, e.g. "March"; the year/day are dummies, only the month matters.
        # NOTE(review): %B is locale-dependent; paths assume English month names.
        month_name = f"{month}_{datetime.date(2000, int(month), 1).strftime('%B')}"

        if img_type not in self.IMG_TYPES:
            raise ValueError("Invalid image type. Must be either 'correction' or 'impulse'")
        if pred_type not in self.PRED_TYPES:
            raise ValueError("Invalid prediction type. Must be either '1d_eb', 'c17_eb' or 'c17_nb'")

        return (
            f"{self.base_path}/{img_type}/{self.index_name}/{year}/{month_name}/"
            f"{day}_{month}_{pred_type}_{self.FILE_SUFFIX}.png"
        )

    def _labels_for_path(self, image_path: str, img_type: str) -> tuple:
        """Return the ``LABEL_COLUMNS`` values of the first sheet row whose image_path matches.

        Raises:
            ValueError: If no row has this ``image_path``. The message suggests the other img_type.
        """
        matches = self.xls_data.loc[self.xls_data["image_path"] == image_path]
        if matches.empty:
            # img_type has already been validated by generate_image_path, so it is one of the two.
            other_type = "impulse" if img_type == "correction" else "correction"
            raise ValueError(
                f"No matching row found for {img_type} in the xls file. "
                f"Try changing the img_type to {other_type}."
            )
        # Per-column .iloc[0] keeps each column's dtype and still works with a duplicated index.
        return tuple(matches[column].iloc[0] for column in self.LABEL_COLUMNS)

    def get_label_tuple(self, date: str, pred_type: str, img_type: str) -> tuple:
        """Return ``(prev_pred, c17_eb_pred, c17_nb_pred)`` for the image identified by the arguments.

        Raises:
            ValueError: For invalid arguments (see generate_image_path) or if the image is not in the sheet.
        """
        image_path = self.generate_image_path(date, pred_type, img_type)
        return self._labels_for_path(image_path, img_type)

    def get_matching_image_paths(self, date: str, pred_type: str, img_type: str) -> List[str]:
        """Return all other image paths in the sheet whose three labels equal the query image's.

        The query image's own path is excluded. Rows are not filtered by pred_type or img_type.

        Raises:
            ValueError: For invalid arguments, if the query image is not in the sheet, or if the
                query row cannot match itself (see below).
        """
        query_image_path = self.generate_image_path(date, pred_type, img_type)
        query_labels = self._labels_for_path(query_image_path, img_type)

        data = self.xls_data
        same_labels = pd.Series(True, index=data.index)
        for column, value in zip(self.LABEL_COLUMNS, query_labels):
            same_labels &= (data[column] == value)
        is_query = data["image_path"] == query_image_path

        # query_labels came from a row at query_image_path, so in practice this only fails
        # when one of its labels is NaN/None (those never compare equal, even to themselves).
        if not (same_labels & is_query).any():
            raise ValueError("No matching row found for the given query.")

        return data.loc[same_labels & ~is_query, "image_path"].tolist()


In [102]:
# Build the lookup helper once; the following cells reuse `query`.
query = QueryImage(
    base_path="./data/stock_indices_data",
    index_name="NIFTY50",
    xls_path=xls_path,
)
image_path = query.generate_image_path(date="24-03-2020", pred_type="1d_eb", img_type="impulse")
print(image_path)  # expected: ./data/stock_indices_data/impulse/NIFTY50/2020/03_March/24_03_1d_eb_15-25.png


./data/stock_indices_data/impulse/NIFTY50/2020/03_March/24_03_1d_eb_15-25.png


In [103]:
# Same path construction, this time for a 'correction' image.
image_path = query.generate_image_path(date="09-11-2022", pred_type="1d_eb", img_type="correction")
print(image_path)


./data/stock_indices_data/correction/NIFTY50/2022/11_November/09_11_1d_eb_15-25.png


In [104]:
# Labels stored in the sheet for this image, in the order (prev_pred, c17_eb_pred, c17_nb_pred).
label_tuple = query.get_label_tuple(date="09-11-2022", pred_type="1d_eb", img_type="correction")
print(label_tuple)

('impulse', 'impulse', 'impulse')


In [105]:
# Every other image in the sheet whose three labels equal this image's labels.
matching_image_paths = query.get_matching_image_paths(date="09-11-2022", pred_type="1d_eb", img_type="correction")
for matching_path in matching_image_paths:
    print(matching_path)

./data/stock_indices_data/impulse/NIFTY50/2020/10_October/05_10_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/10_October/06_10_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/10_October/07_10_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/10_October/08_10_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/10_October/09_10_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/02_11_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/06_11_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/09_11_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/10_11_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/17_11_1d_eb_15-25.png
./data/stock_indices_data/impulse/NIFTY50/2020/11_November/23_11_1d_eb_15-25.png
./data/stock_indices_data/correction/NIFTY50/2020/11_November/25_11_1d_eb_15-25.png
./data/stock_indices_data/impu