Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,27 @@ public class PythonCodeGenerator extends AbstractExternalGenerator {

private Map<String, PythonCodeGeneratorContext> contexts = null;

/**
* Optional override for the pyproject.toml project name.
* When null, the name is derived from the namespace as "python-&lt;first-segment&gt;".
*/
private String projectName = null;

public PythonCodeGenerator() {
super(PYTHON);
contexts = new HashMap<>();
}

/**
* Overrides the pyproject.toml project name. When not set (or set to null),
* the name is derived from the namespace as "python-&lt;first-segment&gt;".
*
* @param projectName the project name, or null for default behaviour
*/
public void setProjectName(String projectName) {
this.projectName = projectName;
}

@Override
public Map<String, ? extends CharSequence> beforeAllGenerate(ResourceSet set,
Collection<? extends RosettaModel> models, String version) {
Expand Down Expand Up @@ -180,7 +196,7 @@ private Map<String, CharSequence> processDAG(String nameSpace, PythonCodeGenerat
Set<String> enumImports = context.getEnumImports();

if (nameSpaceObjects != null && !nameSpaceObjects.isEmpty() && dependencyDAG != null && enumImports != null) {
result.put(PYPROJECT_TOML, PythonCodeGeneratorUtil.createPYProjectTomlFile(nameSpace, cleanVersion));
result.put(PYPROJECT_TOML, PythonCodeGeneratorUtil.createPYProjectTomlFile(nameSpace, cleanVersion, projectName));
PythonCodeWriter bundleWriter = new PythonCodeWriter();
TopologicalOrderIterator<String, DefaultEdge> topologicalOrderIterator = new TopologicalOrderIterator<>(
dependencyDAG);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,17 @@ public int execute(String[] args) {
.desc("Continue generation even if validation errors occur").build();
Option failOnWarningsOpt = Option.builder("w").longOpt("fail-on-warnings")
.desc("Treat validation warnings as errors").build();
Option projectNameOpt = Option.builder("n").longOpt("project-name").argName("projectName")
.desc("Override the pyproject.toml project name (default: python-<first-namespace-segment>)")
.hasArg().build();

options.addOption(help);
options.addOption(srcDirOpt);
options.addOption(srcFileOpt);
options.addOption(tgtDirOpt);
options.addOption(allowErrorsOpt);
options.addOption(failOnWarningsOpt);
options.addOption(projectNameOpt);

CommandLineParser parser = new DefaultParser();
try {
Expand All @@ -114,13 +118,14 @@ public int execute(String[] args) {
String tgtDir = cmd.getOptionValue("t", "./python");
boolean allowErrors = cmd.hasOption("e");
boolean failOnWarnings = cmd.hasOption("w");
String projectName = cmd.getOptionValue("n");

if (cmd.hasOption("s")) {
String srcDir = cmd.getOptionValue("s");
return translateFromSourceDir(srcDir, tgtDir, allowErrors, failOnWarnings);
return translateFromSourceDir(srcDir, tgtDir, allowErrors, failOnWarnings, projectName);
} else if (cmd.hasOption("f")) {
String srcFile = cmd.getOptionValue("f");
return translateFromSourceFile(srcFile, tgtDir, allowErrors, failOnWarnings);
return translateFromSourceFile(srcFile, tgtDir, allowErrors, failOnWarnings, projectName);
} else {
System.err.println("Either a source directory (-s) or source file (-f) must be specified.");
printUsage(options);
Expand All @@ -139,7 +144,11 @@ private static void printUsage(Options options) {
}

protected int translateFromSourceDir(String srcDir, String tgtDir, boolean allowErrors, boolean failOnWarnings) {
// Find all .rosetta files in a directory
return translateFromSourceDir(srcDir, tgtDir, allowErrors, failOnWarnings, null);
}

protected int translateFromSourceDir(String srcDir, String tgtDir, boolean allowErrors, boolean failOnWarnings,
String projectName) {
Path srcDirPath = Paths.get(srcDir);
if (!Files.exists(srcDirPath)) {
LOGGER.error("Source directory does not exist: {}", srcDir);
Expand All @@ -154,14 +163,19 @@ protected int translateFromSourceDir(String srcDir, String tgtDir, boolean allow
.filter(Files::isRegularFile)
.filter(f -> f.getFileName().toString().endsWith(".rosetta"))
.collect(Collectors.toList());
return processRosettaFiles(rosettaFiles, tgtDir, allowErrors, failOnWarnings);
return processRosettaFiles(rosettaFiles, tgtDir, allowErrors, failOnWarnings, projectName);
} catch (IOException e) {
LOGGER.error("Failed to process source directory: {}", srcDir, e);
return 1;
}
}

protected int translateFromSourceFile(String srcFile, String tgtDir, boolean allowErrors, boolean failOnWarnings) {
return translateFromSourceFile(srcFile, tgtDir, allowErrors, failOnWarnings, null);
}

protected int translateFromSourceFile(String srcFile, String tgtDir, boolean allowErrors, boolean failOnWarnings,
String projectName) {
Path srcFilePath = Paths.get(srcFile);
if (!Files.exists(srcFilePath)) {
LOGGER.error("Source file does not exist: {}", srcFile);
Expand All @@ -176,7 +190,7 @@ protected int translateFromSourceFile(String srcFile, String tgtDir, boolean all
return 1;
}
List<Path> rosettaFiles = List.of(srcFilePath);
return processRosettaFiles(rosettaFiles, tgtDir, allowErrors, failOnWarnings);
return processRosettaFiles(rosettaFiles, tgtDir, allowErrors, failOnWarnings, projectName);
}

protected IResourceValidator getValidator(Injector injector) {
Expand All @@ -186,6 +200,11 @@ protected IResourceValidator getValidator(Injector injector) {
// Common processing function
protected int processRosettaFiles(List<Path> rosettaFiles, String tgtDir, boolean allowErrors,
boolean failOnWarnings) {
return processRosettaFiles(rosettaFiles, tgtDir, allowErrors, failOnWarnings, null);
}

protected int processRosettaFiles(List<Path> rosettaFiles, String tgtDir, boolean allowErrors,
boolean failOnWarnings, String projectName) {
LOGGER.info("Processing {} .rosetta files, writing to: {}", rosettaFiles.size(), tgtDir);

if (rosettaFiles.isEmpty()) {
Expand All @@ -204,6 +223,7 @@ protected int processRosettaFiles(List<Path> rosettaFiles, String tgtDir, boolea
.forEach(resources::add);

PythonCodeGenerator pythonCodeGenerator = injector.getInstance(PythonCodeGenerator.class);
pythonCodeGenerator.setProjectName(projectName);
PythonModelLoader modelLoader = injector.getInstance(PythonModelLoader.class);

List<RosettaModel> models = modelLoader.getRosettaModels(resources);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,28 @@ public static String getNamespace(RosettaModel rm) {
}

public static String createPYProjectTomlFile(String namespace, String version) {
return createPYProjectTomlFile(namespace, version, null);
}

public static String createPYProjectTomlFile(String namespace, String version, String projectName) {
String name = (projectName != null && !projectName.isBlank())
? projectName
: "python-" + namespace.split("\\.")[0];
return """
[build-system]
requires = ["setuptools>=62.0"]
build-backend = "setuptools.build_meta"

[project]
name = "python-%s"
name = "%s"
version = "%s"
requires-python = ">= 3.11"
dependencies = [
"pydantic>=2.10.3",
"rune.runtime>=1.0.0,<2.0.0"
]
[tool.setuptools.packages.find]
where = ["src"]""".formatted(namespace, version).stripIndent();
where = ["src"]""".formatted(name, version).stripIndent();
}

public static String cleanVersion(String version) {
Expand Down