Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for Version 1.0.0 release #71

Merged
merged 1 commit into from
Sep 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 49 additions & 17 deletions assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# I M P O R T S ###############################################################

import argparse
import sys

from cocoasm.exceptions import TranslationError, ParseError
from cocoasm.program import Program
from cocoasm.virtualfiles.virtualfile import CoCoFile
from cocoasm.virtualfiles.binary import BinaryFile
from cocoasm.virtualfiles.cassette import CassetteFile
from cocoasm.virtualfiles.virtual_file import VirtualFileType, VirtualFile
from cocoasm.virtualfiles.source_file import SourceFile, SourceFileType
from cocoasm.virtualfiles.coco_file import CoCoFile

# F U N C T I O N S ###########################################################

Expand Down Expand Up @@ -56,14 +58,34 @@ def parse_arguments():
return parser.parse_args()


def throw_error(error):
"""
Prints out an error message.

:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement)))
sys.exit(1)


def main(args):
"""
Runs the assembler with the specified arguments.

:param args: the command-line arguments
"""
program = Program(width=args.width)
program.process(args.filename)
source_file = SourceFile(args.filename)
source_file.read_file()
program = Program()

try:
program.process(source_file.get_buffer())
except TranslationError as error:
throw_error(error)
except ParseError as error:
throw_error(error)

coco_file = CoCoFile(
name=program.name or args.name,
load_addr=program.origin,
Expand All @@ -72,18 +94,25 @@ def main(args):
)

if args.symbols:
program.print_symbol_table()
print("-- Symbol Table --")
for symbol in program.get_symbol_table():
print(symbol)

if args.print:
program.print_statements()
print("-- Assembled Statements --")
for statement in program.get_statements():
print(statement)

if args.bin_file:
try:
binary_file = BinaryFile()
binary_file.open_host_file_for_write(args.bin_file, append=args.append)
binary_file.save_to_host_file(coco_file)
binary_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.bin_file, file_type=SourceFileType.BINARY),
VirtualFileType.BINARY
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save binary file:")
print(error)

Expand All @@ -92,11 +121,14 @@ def main(args):
print("No name for the program specified, not creating cassette file")
return
try:
cas_file = CassetteFile()
cas_file.open_host_file_for_write(args.cas_file, append=args.append)
cas_file.save_to_host_file(coco_file)
cas_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.cas_file, file_type=SourceFileType.BINARY),
VirtualFileType.CASSETTE
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save cassette file:")
print(error)

Expand Down
24 changes: 11 additions & 13 deletions cocoasm/operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def translate(self):
if registers[1] not in REGISTERS:
raise OperandTypeError("[{}] unknown register".format(registers[1]))

post_byte |= 0x00 if registers[0] == "D" else 0x00
post_byte |= 0x00 if registers[1] == "D" else 0x00
# Implicit is that a value of 0x00 means register D

post_byte |= 0x10 if registers[0] == "X" else 0x00
post_byte |= 0x01 if registers[1] == "X" else 0x00
Expand Down Expand Up @@ -471,7 +470,7 @@ def translate(self):
)
size = self.instruction.mode.ind_sz

if not type(self.value) == str and self.value.is_address():
if type(self.value) != str and self.value.is_address():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -481,7 +480,7 @@ def translate(self):
max_size=size,
)

if not type(self.value) == str and self.value.is_numeric():
if type(self.value) != str and self.value.is_numeric():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -507,7 +506,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
if "-" in self.right or "+" in self.right:
if self.right == "X+" or self.right == "Y+" or self.right == "U+" or self.right == "S+":
raise OperandTypeError("[{}] not allowed as an extended indirect value".format(self.right))
Expand Down Expand Up @@ -603,13 +602,12 @@ def __init__(self, operand_string, instruction):
self.right = self.value.right

def resolve_symbols(self, symbol_table):
if self.left != "":
if self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
if self.left != "" and self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
return self

def translate(self):
Expand All @@ -634,7 +632,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
raw_post_byte |= 0x80
if "-" in self.right or "+" in self.right:
if "+" in self.right:
Expand Down
154 changes: 62 additions & 92 deletions cocoasm/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
"""
# I M P O R T S ###############################################################

import sys

from cocoasm.exceptions import TranslationError, ParseError, ValueTypeError
from cocoasm.exceptions import TranslationError, ValueTypeError
from cocoasm.statement import Statement
from cocoasm.values import AddressValue, NoneValue
from cocoasm.virtualfiles.source_file import SourceFile

# C L A S S E S ###############################################################

Expand All @@ -21,59 +20,38 @@ class Program(object):
contains a list of statements. Additionally, a Program keeps track of all
the user-defined symbols in the program.
"""
def __init__(self, width=100):
def __init__(self):
self.symbol_table = dict()
self.statements = []
self.address = 0x0
self.origin = NoneValue()
self.name = None
self.width = width

def process(self, filename):
def process(self, source_file):
"""
Processes a filename for assembly.

:param filename: the name of the file to process
:param source_file: the source file to process
"""
try:
self.parse(filename)
self.translate_statements()
except TranslationError as error:
self.throw_error(error)
except ParseError as error:
self.throw_error(error)
self.statements = self.parse(source_file)
self.translate_statements()

def parse(self, filename):
@classmethod
def parse(cls, contents):
"""
Parses a single file and saves the set of statements.

:param filename: the name of the file to process
"""
self.statements = self.parse_file(filename)

def parse_file(self, filename):
"""
Parses all of the lines in a file, and transforms each line into
a Statement. Returns a list of all the statements in the file.

:param filename: the name of the file to parse
:param contents: a list of strings, each string represents one line of assembly
"""
statements = []
if not filename:
return statements

try:
with open(filename) as infile:
for line in infile:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
except ParseError as error:
self.throw_error(error)

for line in contents:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
return statements

def process_mnemonics(self, statements):
@classmethod
def process_mnemonics(cls, statements):
"""
Given a list of statements, processes the mnemonics on each statement, and
assigns each statement an Instruction object. If the statement is the
Expand All @@ -85,8 +63,14 @@ def process_mnemonics(self, statements):
"""
processed_statements = []
for statement in statements:
include = self.process_mnemonics(self.parse_file(statement.get_include_filename()))
processed_statements.extend(include if include else [statement])
include_filename = statement.get_include_filename()
if include_filename:
include_source = SourceFile(include_filename)
include_source.read_file()
include = cls.process_mnemonics(cls.parse(include_source.get_buffer()))
processed_statements.extend(include)
else:
processed_statements.extend([statement])
return processed_statements

def save_symbol(self, index, statement):
Expand All @@ -112,46 +96,40 @@ def translate_statements(self):
Translates all the parsed statements into their respective
opcodes.
"""
try:
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)

for index, statement in enumerate(self.statements):
try:
statement.resolve_symbols(self.symbol_table)
except ValueTypeError as error:
raise TranslationError(str(error), statement)
for index, statement in enumerate(self.statements):
statement.resolve_symbols(self.symbol_table)

for index, statement in enumerate(self.statements):
statement.translate()

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
statement.translate()
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)
address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size

address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size
for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string
except TranslationError as error:
self.throw_error(error)
# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string

def get_binary_array(self):
"""
Expand Down Expand Up @@ -187,30 +165,22 @@ def all_sizes_fixed(self):
return False
return True

def print_symbol_table(self):
def get_symbol_table(self):
"""
Prints out the symbol table and any values contained within it.
Returns a list of strings. Each string contains one entry from the symbol table.
"""
print("-- Symbol Table --")
lines = []
for symbol, value in self.symbol_table.items():
print("${} {}".format(value.hex().ljust(4, ' '), symbol))
lines.append("${} {}".format(value.hex().ljust(4, ' '), symbol))
return lines

def print_statements(self):
def get_statements(self):
"""
Prints out the assembled statements.
Returns a list of strings. Each string represents one assembled statement
"""
print("-- Assembled Statements --")
lines = []
for index, statement in enumerate(self.statements):
print("{}".format(str(statement).ljust(self.width)))

def throw_error(self, error):
"""
Prints out an error message.

:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement).ljust(self.width)))
sys.exit(1)
lines.append("{}".format(str(statement)))
return lines

# E N D O F F I L E #######################################################