In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import os

# Define file paths
file_paths = [
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/机器人.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体材料.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体设备.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/人工智能.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/计算机.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/软件开发.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/电子.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/自动化设备.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/通信设备.csv'
]

# Create output directory if it doesn't exist
output_dir = '/Users/ash/Desktop/毕业/writer/output/AI研报/image'
os.makedirs(output_dir, exist_ok=True)

# Load and preprocess each file
data_frames = []
for file_path in file_paths:
    # Read the CSV file
    df = pd.read_csv(file_path)
    
    # Make a copy of the DataFrame
    df_processed = df.copy()
    
    # Convert 'trade_date' to datetime format
    df_processed['trade_date'] = pd.to_datetime(df_processed['trade_date'], format='%Y%m%d')
    
    # Extract year, month, and day
    df_processed['year'] = df_processed['trade_date'].dt.year
    df_processed['month'] = df_processed['trade_date'].dt.month
    df_processed['day'] = df_processed['trade_date'].dt.day
    
    # Get the industry name from the filename
    industry_name = os.path.basename(file_path).replace('.csv', '')
    df_processed['industry'] = industry_name
    
    data_frames.append(df_processed)

# Combine all data into one DataFrame for plotting
combined_df = pd.concat(data_frames)

# Plotting setup
font_path = "/Library/Fonts/SimHei.ttf"
font_prop = FontProperties(fname=font_path)

fig, ax = plt.subplots(figsize=(12, 8))

# Plot each industry's total_mv over time
for industry, group in combined_df.groupby('industry'):
    ax.plot(group['trade_date'], group['total_mv'], label=industry)

# Formatting the plot
ax.set_xlabel('交易日期', fontproperties=font_prop)
ax.set_ylabel('总市值 (万元)', fontproperties=font_prop)
ax.set_title('申万行业指数总市值趋势', fontproperties=font_prop)
ax.legend(prop=font_prop)
plt.xticks(rotation=45)

# Save the plot
output_path = os.path.join(output_dir, '申万行业指数总市值趋势.png')
fig.savefig(output_path, bbox_inches='tight', dpi=300)
plt.close()


FileNotFoundError: [Errno 2] No such file or directory: '/Users/ash/Desktop/毕业/writer/data/申万行业指数/人工智能.csv'

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import os

# Define file paths
file_paths = [
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/机器人.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体材料.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/半导体设备.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/人工智能.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/计算机.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/软件开发.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/电子.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/自动化设备.csv',
    '/Users/ash/Desktop/毕业/writer/data/申万行业指数/通信设备.csv'
]

# Create output directory if it doesn't exist
output_dir = '/Users/ash/Desktop/毕业/writer/output/AI研报/image'
os.makedirs(output_dir, exist_ok=True)

# Load and preprocess each file
data_frames = []
for file_path in file_paths:
    try:
        # Check if file exists
        if not os.path.exists(file_path):
            print(f"File not found, skipping: {file_path}")
            continue
            
        # Read the CSV file
        df = pd.read_csv(file_path)
        
        # Make a copy of the DataFrame
        df_processed = df.copy()
        
        # Convert 'trade_date' to datetime format
        df_processed['trade_date'] = pd.to_datetime(df_processed['trade_date'], format='%Y%m%d')
        
        # Extract year, month, and day
        df_processed['year'] = df_processed['trade_date'].dt.year
        df_processed['month'] = df_processed['trade_date'].dt.month
        df_processed['day'] = df_processed['trade_date'].dt.day
        
        # Get the industry name from the filename
        industry_name = os.path.basename(file_path).replace('.csv', '')
        df_processed['industry'] = industry_name
        
        data_frames.append(df_processed)
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        continue

# Check if we have any data to plot
if not data_frames:
    print("No valid data files found to process")
else:
    # Combine all data into one DataFrame for plotting
    combined_df = pd.concat(data_frames)

    # Plotting setup
    font_path = "/Library/Fonts/SimHei.ttf"
    font_prop = FontProperties(fname=font_path)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot each industry's total_mv over time
    for industry, group in combined_df.groupby('industry'):
        ax.plot(group['trade_date'], group['total_mv'], label=industry)

    # Formatting the plot
    ax.set_xlabel('交易日期', fontproperties=font_prop)
    ax.set_ylabel('总市值 (万元)', fontproperties=font_prop)
    ax.set_title('申万行业指数总市值趋势', fontproperties=font_prop)
    ax.legend(prop=font_prop)
    plt.xticks(rotation=45)

    # Save the plot
    output_path = os.path.join(output_dir, '申万行业指数总市值趋势.png')
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Plot saved successfully to {output_path}")


File not found, skipping: /Users/ash/Desktop/毕业/writer/data/申万行业指数/人工智能.csv


Plot saved successfully to /Users/ash/Desktop/毕业/writer/output/AI研报/image/申万行业指数总市值趋势.png


In [3]:
# Enhance the existing plot with better visualization
font_path = "/Library/Fonts/SimHei.ttf"
font_prop = FontProperties(fname=font_path)

# Create a new figure with improved styling
fig, ax = plt.subplots(figsize=(14, 8))

# Use different line styles and markers for better distinction
line_styles = ['-', '--', '-.', ':']
markers = ['o', 's', '^', 'v', 'D', 'p', '*', 'h', 'x', '+']

for i, (industry, group) in enumerate(combined_df.groupby('industry')):
    # Cycle through different line styles and markers
    style = line_styles[i % len(line_styles)]
    marker = markers[i % len(markers)]
    ax.plot(group['trade_date'], group['total_mv'], 
            label=industry, 
            linestyle=style,
            marker=marker,
            markersize=4,
            linewidth=1.5)

# Improve plot formatting
ax.set_xlabel('交易日期', fontproperties=font_prop, fontsize=12)
ax.set_ylabel('总市值 (万元)', fontproperties=font_prop, fontsize=12)
ax.set_title('申万行业指数总市值趋势对比', fontproperties=font_prop, fontsize=14)
ax.grid(True, linestyle='--', alpha=0.6)
ax.legend(prop=font_prop, bbox_to_anchor=(1.05, 1), loc='upper left')

# Format x-axis
plt.xticks(rotation=45, ha='right')
plt.tight_layout()

# Save the enhanced plot
enhanced_output_path = os.path.join(output_dir, '申万行业指数总市值趋势_增强版.png')
fig.savefig(enhanced_output_path, bbox_inches='tight', dpi=300)
plt.close()

print(f"Enhanced plot saved successfully to {enhanced_output_path}")


Enhanced plot saved successfully to /Users/ash/Desktop/毕业/writer/output/AI研报/image/申万行业指数总市值趋势_增强版.png


In [4]:
# Verify the saved plots in the output directory
import os
output_dir = '/Users/ash/Desktop/毕业/writer/output/AI研报/image'

# List all files in the output directory that contain '申万行业指数'
saved_files = [f for f in os.listdir(output_dir) if '申万行业指数' in f and f.endswith('.png')]

print("Saved plot files containing '申万行业指数':")
for file in saved_files:
    print(f"- {file} (full path: {os.path.join(output_dir, file)})")
    
# If you need to save additional versions with different parameters, you can modify this:
if saved_files:  # If files exist, we can create another version
    fig, ax = plt.subplots(figsize=(14, 8))
    for i, (industry, group) in enumerate(combined_df.groupby('industry')):
        ax.plot(group['trade_date'], group['total_mv'], label=industry)
    ax.set_xlabel('交易日期', fontproperties=font_prop)
    ax.set_ylabel('总市值 (万元)', fontproperties=font_prop)
    ax.set_title('申万行业指数总市值趋势-简洁版', fontproperties=font_prop)
    ax.legend(prop=font_prop, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    simple_output_path = os.path.join(output_dir, '申万行业指数总市值趋势_简洁版.png')
    fig.savefig(simple_output_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"\nAdditional simple version saved to: {simple_output_path}")


Saved plot files containing '申万行业指数':
- 申万行业指数总市值趋势.png (full path: /Users/ash/Desktop/毕业/writer/output/AI研报/image/申万行业指数总市值趋势.png)
- 申万行业指数总市值趋势_增强版.png (full path: /Users/ash/Desktop/毕业/writer/output/AI研报/image/申万行业指数总市值趋势_增强版.png)



Additional simple version saved to: /Users/ash/Desktop/毕业/writer/output/AI研报/image/申万行业指数总市值趋势_简洁版.png
